mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
refactor(routers): reorganize routers with new langgraph/ subdirectory
Restructure app/gateway/routers/: - Add langgraph/ subdirectory for LangGraph-related endpoints: - threads.py - thread management - runs.py - run execution and streaming - feedback.py - feedback endpoints - suggestions.py - follow-up suggestions Remove old standalone routers: - threads.py → langgraph/threads.py - thread_runs.py → langgraph/runs.py - runs.py (stateless) → langgraph/runs.py - feedback.py → langgraph/feedback.py Update existing routers: - memory.py, uploads.py, artifacts.py, suggestions.py Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,3 +1,3 @@
|
|||||||
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
|
from . import artifacts, mcp, models, skills, suggestions, uploads
|
||||||
|
|
||||||
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
|
__all__ = ["artifacts", "mcp", "models", "skills", "suggestions", "uploads"]
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from urllib.parse import quote
|
|||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from fastapi.responses import FileResponse, PlainTextResponse, Response
|
from fastapi.responses import FileResponse, PlainTextResponse, Response
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
|
||||||
from app.gateway.path_utils import resolve_thread_virtual_path
|
from app.gateway.path_utils import resolve_thread_virtual_path
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -82,7 +81,6 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
|
|||||||
summary="Get Artifact File",
|
summary="Get Artifact File",
|
||||||
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
|
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
|
||||||
)
|
)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
|
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
|
||||||
"""Get an artifact file by its path.
|
"""Get an artifact file by its path.
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
from .feedback import router as feedback_router
|
||||||
|
from .runs import router as runs_router
|
||||||
|
from .suggestions import router as suggestion_router
|
||||||
|
from .threads import router as threads_router
|
||||||
|
|
||||||
|
__all__ = ["feedback_router", "runs_router", "threads_router", "suggestion_router"]
|
||||||
+84
-93
@@ -1,8 +1,4 @@
|
|||||||
"""Feedback endpoints — create, list, stats, delete.
|
"""LangGraph-compatible run feedback endpoints."""
|
||||||
|
|
||||||
Allows users to submit thumbs-up/down feedback on runs,
|
|
||||||
optionally scoped to a specific message.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -12,16 +8,12 @@ from typing import Any
|
|||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.dependencies import get_feedback_repository, get_run_repository
|
||||||
from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store
|
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
|
||||||
|
from app.plugins.auth.security.dependencies import get_current_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/threads", tags=["feedback"])
|
router = APIRouter(tags=["feedback"])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Request / response models
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackCreateRequest(BaseModel):
|
class FeedbackCreateRequest(BaseModel):
|
||||||
@@ -30,16 +22,11 @@ class FeedbackCreateRequest(BaseModel):
|
|||||||
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
||||||
|
|
||||||
|
|
||||||
class FeedbackUpsertRequest(BaseModel):
|
|
||||||
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
|
|
||||||
comment: str | None = Field(default=None, description="Optional text feedback")
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackResponse(BaseModel):
|
class FeedbackResponse(BaseModel):
|
||||||
feedback_id: str
|
feedback_id: str
|
||||||
run_id: str
|
run_id: str
|
||||||
thread_id: str
|
thread_id: str
|
||||||
user_id: str | None = None
|
owner_id: str | None = None
|
||||||
message_id: str | None = None
|
message_id: str | None = None
|
||||||
rating: int
|
rating: int
|
||||||
comment: str | None = None
|
comment: str | None = None
|
||||||
@@ -53,85 +40,36 @@ class FeedbackStatsResponse(BaseModel):
|
|||||||
negative: int = 0
|
negative: int = 0
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
async def _validate_run_scope(thread_id: str, run_id: str, request: Request) -> None:
|
||||||
# Endpoints
|
run_store = get_run_repository(request)
|
||||||
# ---------------------------------------------------------------------------
|
if resolve_request_user_id(request) is None:
|
||||||
|
run = await run_store.get(run_id, user_id=None)
|
||||||
|
else:
|
||||||
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
with bind_request_actor_context(request):
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
run = await run_store.get(run_id)
|
||||||
async def upsert_feedback(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
body: FeedbackUpsertRequest,
|
|
||||||
request: Request,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Create or update feedback for a run (idempotent)."""
|
|
||||||
if body.rating not in (1, -1):
|
|
||||||
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
|
||||||
|
|
||||||
user_id = await get_current_user(request)
|
|
||||||
|
|
||||||
run_store = get_run_store(request)
|
|
||||||
run = await run_store.get(run_id)
|
|
||||||
if run is None:
|
if run is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
if run.get("thread_id") != thread_id:
|
if run.get("thread_id") != thread_id:
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
||||||
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
return await feedback_repo.upsert(
|
async def _get_current_user(request: Request) -> str | None:
|
||||||
run_id=run_id,
|
"""Extract current user id from auth dependencies when available."""
|
||||||
thread_id=thread_id,
|
return await get_current_user_id(request)
|
||||||
rating=body.rating,
|
|
||||||
user_id=user_id,
|
|
||||||
comment=body.comment,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{thread_id}/runs/{run_id}/feedback")
|
async def _create_feedback(
|
||||||
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
|
||||||
async def delete_run_feedback(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete the current user's feedback for a run."""
|
|
||||||
user_id = await get_current_user(request)
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
deleted = await feedback_repo.delete_by_run(
|
|
||||||
thread_id=thread_id,
|
|
||||||
run_id=run_id,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
if not deleted:
|
|
||||||
raise HTTPException(status_code=404, detail="No feedback found for this run")
|
|
||||||
return {"success": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
|
||||||
async def create_feedback(
|
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
body: FeedbackCreateRequest,
|
body: FeedbackCreateRequest,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Submit feedback (thumbs-up/down) for a run."""
|
|
||||||
if body.rating not in (1, -1):
|
if body.rating not in (1, -1):
|
||||||
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
||||||
|
|
||||||
user_id = await get_current_user(request)
|
await _validate_run_scope(thread_id, run_id, request)
|
||||||
|
user_id = await _get_current_user(request)
|
||||||
# Validate run exists and belongs to thread
|
feedback_repo = get_feedback_repository(request)
|
||||||
run_store = get_run_store(request)
|
|
||||||
run = await run_store.get(run_id)
|
|
||||||
if run is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
if run.get("thread_id") != thread_id:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
|
||||||
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
return await feedback_repo.create(
|
return await feedback_repo.create(
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -142,41 +80,94 @@ async def create_feedback(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||||
|
async def upsert_feedback(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
body: FeedbackCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create or replace the run-level feedback record."""
|
||||||
|
feedback_repo = get_feedback_repository(request)
|
||||||
|
user_id = await _get_current_user(request)
|
||||||
|
if user_id is not None:
|
||||||
|
return await feedback_repo.upsert(
|
||||||
|
run_id=run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
rating=body.rating,
|
||||||
|
user_id=user_id,
|
||||||
|
comment=body.comment,
|
||||||
|
)
|
||||||
|
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
|
||||||
|
for item in existing:
|
||||||
|
feedback_id = item.get("feedback_id")
|
||||||
|
if isinstance(feedback_id, str):
|
||||||
|
await feedback_repo.delete(feedback_id)
|
||||||
|
return await _create_feedback(thread_id, run_id, body, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||||
|
async def create_feedback(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
body: FeedbackCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Submit feedback for a run."""
|
||||||
|
return await _create_feedback(thread_id, run_id, body, request)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
|
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def list_feedback(
|
async def list_feedback(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""List all feedback for a run."""
|
"""List all feedback for a run."""
|
||||||
feedback_repo = get_feedback_repo(request)
|
feedback_repo = get_feedback_repository(request)
|
||||||
return await feedback_repo.list_by_run(thread_id, run_id)
|
user_id = await _get_current_user(request)
|
||||||
|
return await feedback_repo.list_by_run(thread_id, run_id, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
|
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def feedback_stats(
|
async def feedback_stats(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Get aggregated feedback stats (positive/negative counts) for a run."""
|
"""Get aggregated feedback stats for a run."""
|
||||||
feedback_repo = get_feedback_repo(request)
|
feedback_repo = get_feedback_repository(request)
|
||||||
return await feedback_repo.aggregate_by_run(thread_id, run_id)
|
return await feedback_repo.aggregate_by_run(thread_id, run_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}/runs/{run_id}/feedback")
|
||||||
|
async def delete_run_feedback(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete all feedback records for a run."""
|
||||||
|
feedback_repo = get_feedback_repository(request)
|
||||||
|
user_id = await _get_current_user(request)
|
||||||
|
if user_id is not None:
|
||||||
|
return {"success": await feedback_repo.delete_by_run(thread_id=thread_id, run_id=run_id, user_id=user_id)}
|
||||||
|
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
|
||||||
|
for item in existing:
|
||||||
|
feedback_id = item.get("feedback_id")
|
||||||
|
if isinstance(feedback_id, str):
|
||||||
|
await feedback_repo.delete(feedback_id)
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
|
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
|
||||||
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
|
||||||
async def delete_feedback(
|
async def delete_feedback(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
feedback_id: str,
|
feedback_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Delete a feedback record."""
|
"""Delete a single feedback record."""
|
||||||
feedback_repo = get_feedback_repo(request)
|
feedback_repo = get_feedback_repository(request)
|
||||||
# Verify feedback belongs to the specified thread/run before deleting
|
|
||||||
existing = await feedback_repo.get(feedback_id)
|
existing = await feedback_repo.get(feedback_id)
|
||||||
if existing is None:
|
if existing is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||||
@@ -0,0 +1,501 @@
|
|||||||
|
"""LangGraph-compatible runs endpoints backed by RunsFacade."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import Response, StreamingResponse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||||
|
from app.gateway.services.runs.facade_factory import build_runs_facade_from_request
|
||||||
|
from app.gateway.services.runs.input import (
|
||||||
|
AdaptedRunRequest,
|
||||||
|
RunSpecBuilder,
|
||||||
|
UnsupportedRunFeatureError,
|
||||||
|
adapt_create_run_request,
|
||||||
|
adapt_create_stream_request,
|
||||||
|
adapt_create_wait_request,
|
||||||
|
adapt_join_stream_request,
|
||||||
|
adapt_join_wait_request,
|
||||||
|
)
|
||||||
|
from deerflow.runtime.runs.types import RunRecord, RunSpec
|
||||||
|
from deerflow.runtime.stream_bridge import JSONValue, StreamEvent
|
||||||
|
|
||||||
|
router = APIRouter(tags=["runs"])
|
||||||
|
|
||||||
|
|
||||||
|
class RunCreateRequest(BaseModel):
|
||||||
|
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
|
||||||
|
follow_up_to_run_id: str | None = Field(default=None, description="Lineage link to the prior run")
|
||||||
|
input: dict[str, JSONValue] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
|
||||||
|
command: dict[str, JSONValue] | None = Field(default=None, description="LangGraph Command")
|
||||||
|
metadata: dict[str, JSONValue] | None = Field(default=None, description="Run metadata")
|
||||||
|
config: dict[str, JSONValue] | None = Field(default=None, description="RunnableConfig overrides")
|
||||||
|
context: dict[str, JSONValue] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
||||||
|
webhook: str | None = Field(default=None, description="Completion callback URL")
|
||||||
|
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
||||||
|
checkpoint: dict[str, JSONValue] | None = Field(default=None, description="Full checkpoint object")
|
||||||
|
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
|
||||||
|
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
|
||||||
|
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
|
||||||
|
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
|
||||||
|
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
|
||||||
|
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
|
||||||
|
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
|
||||||
|
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
|
||||||
|
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
||||||
|
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
||||||
|
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
||||||
|
|
||||||
|
|
||||||
|
class RunResponse(BaseModel):
|
||||||
|
run_id: str
|
||||||
|
thread_id: str
|
||||||
|
assistant_id: str | None = None
|
||||||
|
status: str
|
||||||
|
metadata: dict[str, JSONValue] = Field(default_factory=dict)
|
||||||
|
multitask_strategy: str = "reject"
|
||||||
|
created_at: str = ""
|
||||||
|
updated_at: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class RunDeleteResponse(BaseModel):
|
||||||
|
deleted: bool
|
||||||
|
|
||||||
|
|
||||||
|
class RunMessageResponse(BaseModel):
|
||||||
|
run_id: str
|
||||||
|
content: JSONValue
|
||||||
|
metadata: dict[str, JSONValue] = Field(default_factory=dict)
|
||||||
|
created_at: str
|
||||||
|
seq: int
|
||||||
|
|
||||||
|
|
||||||
|
class RunMessagesResponse(BaseModel):
|
||||||
|
data: list[RunMessageResponse]
|
||||||
|
hasMore: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def format_sse(event: str, data: JSONValue, *, event_id: str | None = None) -> str:
|
||||||
|
"""Format a single SSE frame."""
|
||||||
|
payload = json.dumps(data, default=str, ensure_ascii=False)
|
||||||
|
parts = [f"event: {event}", f"data: {payload}"]
|
||||||
|
if event_id:
|
||||||
|
parts.append(f"id: {event_id}")
|
||||||
|
parts.append("")
|
||||||
|
parts.append("")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _record_to_response(record: RunRecord) -> RunResponse:
|
||||||
|
return RunResponse(
|
||||||
|
run_id=record.run_id,
|
||||||
|
thread_id=record.thread_id,
|
||||||
|
assistant_id=record.assistant_id,
|
||||||
|
status=record.status,
|
||||||
|
metadata=record.metadata,
|
||||||
|
multitask_strategy=record.multitask_strategy,
|
||||||
|
created_at=record.created_at,
|
||||||
|
updated_at=record.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _trim_paginated_rows(
|
||||||
|
rows: list[dict],
|
||||||
|
*,
|
||||||
|
limit: int,
|
||||||
|
after_seq: int | None,
|
||||||
|
) -> tuple[list[dict], bool]:
|
||||||
|
has_more = len(rows) > limit
|
||||||
|
if not has_more:
|
||||||
|
return rows, False
|
||||||
|
if after_seq is not None:
|
||||||
|
return rows[:limit], True
|
||||||
|
return rows[-limit:], True
|
||||||
|
|
||||||
|
|
||||||
|
def _event_to_run_message(event: dict) -> RunMessageResponse:
|
||||||
|
return RunMessageResponse(
|
||||||
|
run_id=str(event["run_id"]),
|
||||||
|
content=event.get("content"),
|
||||||
|
metadata=dict(event.get("metadata") or {}),
|
||||||
|
created_at=str(event.get("created_at") or ""),
|
||||||
|
seq=int(event["seq"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _sse_consumer(
|
||||||
|
stream: AsyncIterator[StreamEvent],
|
||||||
|
request: Request,
|
||||||
|
*,
|
||||||
|
cancel_on_disconnect: bool,
|
||||||
|
cancel_run,
|
||||||
|
run_id: str,
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
try:
|
||||||
|
async for event in stream:
|
||||||
|
if await request.is_disconnected():
|
||||||
|
break
|
||||||
|
|
||||||
|
if event.event == "__heartbeat__":
|
||||||
|
yield ": heartbeat\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event.event == "__end__":
|
||||||
|
yield format_sse("end", None, event_id=event.id or None)
|
||||||
|
return
|
||||||
|
|
||||||
|
if event.event == "__cancelled__":
|
||||||
|
yield format_sse("cancel", None, event_id=event.id or None)
|
||||||
|
return
|
||||||
|
|
||||||
|
yield format_sse(event.event, event.data, event_id=event.id or None)
|
||||||
|
finally:
|
||||||
|
if cancel_on_disconnect:
|
||||||
|
await cancel_run(run_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_run_event_store(request: Request):
|
||||||
|
event_store = getattr(request.app.state, "run_event_store", None)
|
||||||
|
if event_store is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Run event store not available")
|
||||||
|
return event_store
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
|
||||||
|
async def list_runs(
|
||||||
|
thread_id: str,
|
||||||
|
request: Request,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
status: str | None = None,
|
||||||
|
) -> list[RunResponse]:
|
||||||
|
# Accepted for API compatibility; field projection is not implemented yet.
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
records = await facade.list_runs(thread_id)
|
||||||
|
if status is not None:
|
||||||
|
records = [record for record in records if record.status == status]
|
||||||
|
records = records[offset : offset + limit]
|
||||||
|
return [_record_to_response(record) for record in records]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
|
||||||
|
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
return _record_to_response(record)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs/{run_id}/messages", response_model=RunMessagesResponse)
|
||||||
|
async def run_messages(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
limit: int = 50,
|
||||||
|
before_seq: int | None = None,
|
||||||
|
after_seq: int | None = None,
|
||||||
|
) -> RunMessagesResponse:
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
|
||||||
|
event_store = _get_run_event_store(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
rows = await event_store.list_messages_by_run(
|
||||||
|
thread_id,
|
||||||
|
run_id,
|
||||||
|
limit=limit + 1,
|
||||||
|
before_seq=before_seq,
|
||||||
|
after_seq=after_seq,
|
||||||
|
)
|
||||||
|
page, has_more = _trim_paginated_rows(rows, limit=limit, after_seq=after_seq)
|
||||||
|
return RunMessagesResponse(data=[_event_to_run_message(row) for row in page], hasMore=has_more)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_spec(
|
||||||
|
*,
|
||||||
|
adapted: AdaptedRunRequest,
|
||||||
|
) -> RunSpec:
|
||||||
|
try:
|
||||||
|
return RunSpecBuilder().build(adapted)
|
||||||
|
except UnsupportedRunFeatureError as exc:
|
||||||
|
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs", response_model=RunResponse)
|
||||||
|
async def create_run(
|
||||||
|
thread_id: str,
|
||||||
|
body: RunCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> Response:
|
||||||
|
adapted = adapt_create_run_request(
|
||||||
|
thread_id=thread_id,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.create_background(spec)
|
||||||
|
return Response(
|
||||||
|
content=_record_to_response(record).model_dump_json(),
|
||||||
|
media_type="application/json",
|
||||||
|
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/stream")
|
||||||
|
async def stream_run(
|
||||||
|
thread_id: str,
|
||||||
|
body: RunCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> StreamingResponse:
|
||||||
|
adapted = adapt_create_stream_request(
|
||||||
|
thread_id=thread_id,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record, stream = await facade.create_and_stream(spec)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
_sse_consumer(
|
||||||
|
stream,
|
||||||
|
request,
|
||||||
|
cancel_on_disconnect=spec.on_disconnect == "cancel",
|
||||||
|
cancel_run=facade.cancel,
|
||||||
|
run_id=record.run_id,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/wait")
|
||||||
|
async def wait_run(
|
||||||
|
thread_id: str,
|
||||||
|
body: RunCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> Response:
|
||||||
|
adapted = adapt_create_wait_request(
|
||||||
|
thread_id=thread_id,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record, result = await facade.create_and_wait(spec)
|
||||||
|
return Response(
|
||||||
|
content=json.dumps(result, default=str, ensure_ascii=False),
|
||||||
|
media_type="application/json",
|
||||||
|
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/runs", response_model=RunResponse)
|
||||||
|
async def create_stateless_run(body: RunCreateRequest, request: Request) -> Response:
|
||||||
|
adapted = adapt_create_run_request(
|
||||||
|
thread_id=None,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.create_background(spec)
|
||||||
|
return Response(
|
||||||
|
content=_record_to_response(record).model_dump_json(),
|
||||||
|
media_type="application/json",
|
||||||
|
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/runs/stream")
|
||||||
|
async def create_stateless_stream_run(body: RunCreateRequest, request: Request) -> StreamingResponse:
|
||||||
|
adapted = adapt_create_stream_request(
|
||||||
|
thread_id=None,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record, stream = await facade.create_and_stream(spec)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
_sse_consumer(
|
||||||
|
stream,
|
||||||
|
request,
|
||||||
|
cancel_on_disconnect=spec.on_disconnect == "cancel",
|
||||||
|
cancel_run=facade.cancel,
|
||||||
|
run_id=record.run_id,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/runs/wait")
|
||||||
|
async def wait_stateless_run(body: RunCreateRequest, request: Request) -> Response:
|
||||||
|
adapted = adapt_create_wait_request(
|
||||||
|
thread_id=None,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record, result = await facade.create_and_wait(spec)
|
||||||
|
return Response(
|
||||||
|
content=json.dumps(result, default=str, ensure_ascii=False),
|
||||||
|
media_type="application/json",
|
||||||
|
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
|
||||||
|
async def stream_existing_run(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
action: Literal["interrupt", "rollback"] | None = None,
|
||||||
|
wait: bool = False,
|
||||||
|
cancel_on_disconnect: bool = False,
|
||||||
|
stream_mode: str | None = None,
|
||||||
|
) -> StreamingResponse | Response:
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
|
||||||
|
if action is not None:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
cancelled = await facade.cancel(run_id, action=action)
|
||||||
|
if not cancelled:
|
||||||
|
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
|
||||||
|
if wait:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
await facade.join_wait(run_id)
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
adapted = adapt_join_stream_request(
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
stream = await facade.join_stream(run_id, last_event_id=adapted.last_event_id)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
_sse_consumer(
|
||||||
|
stream,
|
||||||
|
request,
|
||||||
|
cancel_on_disconnect=cancel_on_disconnect,
|
||||||
|
cancel_run=facade.cancel,
|
||||||
|
run_id=run_id,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs/{run_id}/join")
|
||||||
|
async def join_existing_run(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
cancel_on_disconnect: bool = False,
|
||||||
|
) -> JSONValue:
|
||||||
|
# Accepted for API compatibility; current join_wait path does not change
|
||||||
|
# behavior based on client disconnect.
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
|
||||||
|
adapted = adapt_join_wait_request(
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
return await facade.join_wait(run_id, last_event_id=adapted.last_event_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/{run_id}/cancel")
|
||||||
|
async def cancel_existing_run(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
wait: bool = False,
|
||||||
|
action: Literal["interrupt", "rollback"] = "interrupt",
|
||||||
|
) -> JSONValue:
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
cancelled = await facade.cancel(run_id, action=action)
|
||||||
|
if not cancelled:
|
||||||
|
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
|
||||||
|
if wait:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
return await facade.join_wait(run_id)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}/runs/{run_id}", response_model=RunDeleteResponse)
|
||||||
|
async def delete_run(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
) -> RunDeleteResponse:
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
deleted = await facade.delete_run(run_id)
|
||||||
|
return RunDeleteResponse(deleted=deleted)
|
||||||
@@ -0,0 +1,132 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api", tags=["suggestions"])
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestionMessage(BaseModel):
|
||||||
|
role: str = Field(..., description="Message role: user|assistant")
|
||||||
|
content: str = Field(..., description="Message content as plain text")
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestionsRequest(BaseModel):
|
||||||
|
messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
|
||||||
|
n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
|
||||||
|
model_name: str | None = Field(default=None, description="Optional model override")
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestionsResponse(BaseModel):
|
||||||
|
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_markdown_code_fence(text: str) -> str:
|
||||||
|
stripped = text.strip()
|
||||||
|
if not stripped.startswith("```"):
|
||||||
|
return stripped
|
||||||
|
lines = stripped.splitlines()
|
||||||
|
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
|
||||||
|
return "\n".join(lines[1:-1]).strip()
|
||||||
|
return stripped
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_string_list(text: str) -> list[str] | None:
|
||||||
|
candidate = _strip_markdown_code_fence(text)
|
||||||
|
start = candidate.find("[")
|
||||||
|
end = candidate.rfind("]")
|
||||||
|
if start == -1 or end == -1 or end <= start:
|
||||||
|
return None
|
||||||
|
candidate = candidate[start : end + 1]
|
||||||
|
try:
|
||||||
|
data = json.loads(candidate)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if not isinstance(data, list):
|
||||||
|
return None
|
||||||
|
out: list[str] = []
|
||||||
|
for item in data:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
continue
|
||||||
|
s = item.strip()
|
||||||
|
if not s:
|
||||||
|
continue
|
||||||
|
out.append(s)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_response_text(content: object) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, str):
|
||||||
|
parts.append(block)
|
||||||
|
elif isinstance(block, dict) and block.get("type") in {"text", "output_text"}:
|
||||||
|
text = block.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "\n".join(parts) if parts else ""
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_conversation(messages: list[SuggestionMessage]) -> str:
|
||||||
|
parts: list[str] = []
|
||||||
|
for m in messages:
|
||||||
|
role = m.role.strip().lower()
|
||||||
|
if role in ("user", "human"):
|
||||||
|
parts.append(f"User: {m.content.strip()}")
|
||||||
|
elif role in ("assistant", "ai"):
|
||||||
|
parts.append(f"Assistant: {m.content.strip()}")
|
||||||
|
else:
|
||||||
|
parts.append(f"{m.role}: {m.content.strip()}")
|
||||||
|
return "\n".join(parts).strip()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/threads/{thread_id}/suggestions",
|
||||||
|
response_model=SuggestionsResponse,
|
||||||
|
summary="Generate Follow-up Questions",
|
||||||
|
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
||||||
|
)
|
||||||
|
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
|
||||||
|
if not request.messages:
|
||||||
|
return SuggestionsResponse(suggestions=[])
|
||||||
|
|
||||||
|
n = request.n
|
||||||
|
conversation = _format_conversation(request.messages)
|
||||||
|
if not conversation:
|
||||||
|
return SuggestionsResponse(suggestions=[])
|
||||||
|
|
||||||
|
system_instruction = (
|
||||||
|
"You are generating follow-up questions to help the user continue the conversation.\n"
|
||||||
|
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
|
||||||
|
"Requirements:\n"
|
||||||
|
"- Questions must be relevant to the preceding conversation.\n"
|
||||||
|
"- Questions must be written in the same language as the user.\n"
|
||||||
|
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
|
||||||
|
"- Do NOT include numbering, markdown, or any extra text.\n"
|
||||||
|
"- Output MUST be a JSON array of strings only.\n"
|
||||||
|
)
|
||||||
|
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
||||||
|
|
||||||
|
try:
|
||||||
|
model = create_chat_model(name=request.model_name, thinking_enabled=False)
|
||||||
|
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
|
||||||
|
raw = _extract_response_text(response.content)
|
||||||
|
suggestions = _parse_json_string_list(raw) or []
|
||||||
|
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
|
||||||
|
cleaned = cleaned[:n]
|
||||||
|
return SuggestionsResponse(suggestions=cleaned)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
|
||||||
|
return SuggestionsResponse(suggestions=[])
|
||||||
@@ -0,0 +1,455 @@
|
|||||||
|
"""Thread management endpoints.
|
||||||
|
|
||||||
|
Provides CRUD operations for threads and checkpoint state management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.gateway.dependencies import CurrentCheckpointer, CurrentRunRepository, CurrentThreadMetaStorage
|
||||||
|
from app.infra.storage import ThreadMetaStorage
|
||||||
|
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
|
||||||
|
from deerflow.config.paths import Paths, get_paths
|
||||||
|
from deerflow.runtime import serialize_channel_values
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter(tags=["threads"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Request / Response Models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadCreateRequest(BaseModel):
|
||||||
|
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
||||||
|
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSearchRequest(BaseModel):
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
|
||||||
|
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
|
||||||
|
offset: int = Field(default=0, ge=0, description="Pagination offset")
|
||||||
|
status: str | None = Field(default=None, description="Filter by thread status")
|
||||||
|
user_id: str | None = Field(default=None, description="Filter by user ID")
|
||||||
|
assistant_id: str | None = Field(default=None, description="Filter by assistant ID")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadResponse(BaseModel):
|
||||||
|
thread_id: str = Field(description="Unique thread identifier")
|
||||||
|
status: str = Field(default="idle", description="Thread status")
|
||||||
|
created_at: str = Field(default="", description="ISO timestamp")
|
||||||
|
updated_at: str = Field(default="", description="ISO timestamp")
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
|
||||||
|
values: dict[str, Any] = Field(default_factory=dict, description="Current state values")
|
||||||
|
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadDeleteResponse(BaseModel):
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadStateUpdateRequest(BaseModel):
|
||||||
|
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
|
||||||
|
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
|
||||||
|
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
||||||
|
as_node: str | None = Field(default=None, description="Node identity for the update")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadStateResponse(BaseModel):
|
||||||
|
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
|
||||||
|
next: list[str] = Field(default_factory=list, description="Next nodes to execute")
|
||||||
|
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
|
||||||
|
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
|
||||||
|
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
|
||||||
|
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
|
||||||
|
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadHistoryRequest(BaseModel):
|
||||||
|
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
|
||||||
|
before: str | None = Field(default=None, description="Cursor for pagination (checkpoint_id)")
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryEntry(BaseModel):
|
||||||
|
checkpoint_id: str
|
||||||
|
parent_checkpoint_id: str | None = None
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
values: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
created_at: str | None = None
|
||||||
|
next: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_log_param(value: str) -> str:
|
||||||
|
"""Strip control characters to prevent log injection."""
|
||||||
|
|
||||||
|
return value.replace("\n", "").replace("\r", "").replace("\x00", "")
|
||||||
|
|
||||||
|
|
||||||
|
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
|
||||||
|
"""Delete local filesystem data for a thread."""
|
||||||
|
path_manager = paths or get_paths()
|
||||||
|
try:
|
||||||
|
path_manager.delete_thread_dir(thread_id)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
|
||||||
|
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
|
||||||
|
|
||||||
|
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
|
||||||
|
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _thread_or_run_exists(
|
||||||
|
*,
|
||||||
|
request: Request,
|
||||||
|
thread_id: str,
|
||||||
|
thread_meta_storage: ThreadMetaStorage,
|
||||||
|
run_repo,
|
||||||
|
) -> bool:
|
||||||
|
request_user_id = resolve_request_user_id(request)
|
||||||
|
|
||||||
|
if request_user_id is None:
|
||||||
|
thread = await thread_meta_storage.get_thread(thread_id, user_id=None)
|
||||||
|
if thread is not None:
|
||||||
|
return True
|
||||||
|
runs = await run_repo.list_by_thread(thread_id, limit=1, user_id=None)
|
||||||
|
return bool(runs)
|
||||||
|
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
thread = await thread_meta_storage.get_thread(thread_id)
|
||||||
|
if thread is not None:
|
||||||
|
return True
|
||||||
|
runs = await run_repo.list_by_thread(thread_id, limit=1)
|
||||||
|
return bool(runs)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=ThreadResponse)
|
||||||
|
async def create_thread(
|
||||||
|
body: ThreadCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
) -> ThreadResponse:
|
||||||
|
"""Create a new thread."""
|
||||||
|
thread_id = body.thread_id or str(uuid.uuid4())
|
||||||
|
|
||||||
|
request_user_id = resolve_request_user_id(request)
|
||||||
|
if request_user_id is None:
|
||||||
|
existing = await thread_meta_storage.get_thread(thread_id, user_id=None)
|
||||||
|
else:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
existing = await thread_meta_storage.get_thread(thread_id)
|
||||||
|
if existing is not None:
|
||||||
|
return ThreadResponse(
|
||||||
|
thread_id=thread_id,
|
||||||
|
status=existing.status,
|
||||||
|
created_at=existing.created_time.isoformat() if existing.created_time else "",
|
||||||
|
updated_at=existing.updated_time.isoformat() if existing.updated_time else "",
|
||||||
|
metadata=existing.metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if request_user_id is None:
|
||||||
|
created = await thread_meta_storage.ensure_thread(
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=body.assistant_id,
|
||||||
|
metadata=body.metadata,
|
||||||
|
user_id=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
created = await thread_meta_storage.ensure_thread(
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=body.assistant_id,
|
||||||
|
metadata=body.metadata,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to create thread %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to create thread")
|
||||||
|
|
||||||
|
logger.info("Thread created: %s", sanitize_log_param(thread_id))
|
||||||
|
return ThreadResponse(
|
||||||
|
thread_id=thread_id,
|
||||||
|
status=created.status,
|
||||||
|
created_at=created.created_time.isoformat() if created.created_time else "",
|
||||||
|
updated_at=created.updated_time.isoformat() if created.updated_time else "",
|
||||||
|
metadata=created.metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/search", response_model=list[ThreadResponse])
|
||||||
|
async def search_threads(
|
||||||
|
body: ThreadSearchRequest,
|
||||||
|
request: Request,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
) -> list[ThreadResponse]:
|
||||||
|
"""Search threads with filters."""
|
||||||
|
try:
|
||||||
|
request_user_id = resolve_request_user_id(request)
|
||||||
|
if request_user_id is None:
|
||||||
|
threads = await thread_meta_storage.search_threads(
|
||||||
|
metadata=body.metadata or None,
|
||||||
|
status=body.status,
|
||||||
|
user_id=body.user_id,
|
||||||
|
assistant_id=body.assistant_id,
|
||||||
|
limit=body.limit,
|
||||||
|
offset=body.offset,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
threads = await thread_meta_storage.search_threads(
|
||||||
|
metadata=body.metadata or None,
|
||||||
|
status=body.status,
|
||||||
|
assistant_id=body.assistant_id,
|
||||||
|
limit=body.limit,
|
||||||
|
offset=body.offset,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to search threads")
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to search threads")
|
||||||
|
|
||||||
|
return [
|
||||||
|
ThreadResponse(
|
||||||
|
thread_id=t.thread_id,
|
||||||
|
status=t.status,
|
||||||
|
created_at=t.created_time.isoformat() if t.created_time else "",
|
||||||
|
updated_at=t.updated_time.isoformat() if t.updated_time else "",
|
||||||
|
metadata=t.metadata,
|
||||||
|
values={"title": t.display_name} if t.display_name else {},
|
||||||
|
interrupts={},
|
||||||
|
)
|
||||||
|
for t in threads
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
|
||||||
|
async def delete_thread(
|
||||||
|
thread_id: str,
|
||||||
|
checkpointer: CurrentCheckpointer,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
) -> ThreadDeleteResponse:
|
||||||
|
"""Delete a thread and all associated data."""
|
||||||
|
response = _delete_thread_data(thread_id)
|
||||||
|
|
||||||
|
# Remove checkpoints (best-effort)
|
||||||
|
try:
|
||||||
|
if hasattr(checkpointer, "adelete_thread"):
|
||||||
|
await checkpointer.adelete_thread(thread_id)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not delete checkpoints for thread %s", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
|
# Remove thread_meta (best-effort)
|
||||||
|
try:
|
||||||
|
await thread_meta_storage.delete_thread(thread_id)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not delete thread_meta for %s", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||||
|
async def get_thread_state(
|
||||||
|
thread_id: str,
|
||||||
|
request: Request,
|
||||||
|
checkpointer: CurrentCheckpointer,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
run_repo: CurrentRunRepository,
|
||||||
|
) -> ThreadStateResponse:
|
||||||
|
"""Get the latest state snapshot for a thread."""
|
||||||
|
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
|
|
||||||
|
try:
|
||||||
|
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||||
|
|
||||||
|
if checkpoint_tuple is None:
|
||||||
|
if await _thread_or_run_exists(
|
||||||
|
request=request,
|
||||||
|
thread_id=thread_id,
|
||||||
|
thread_meta_storage=thread_meta_storage,
|
||||||
|
run_repo=run_repo,
|
||||||
|
):
|
||||||
|
return ThreadStateResponse()
|
||||||
|
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||||
|
|
||||||
|
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||||
|
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||||
|
channel_values = checkpoint.get("channel_values", {})
|
||||||
|
|
||||||
|
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
|
||||||
|
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
|
||||||
|
|
||||||
|
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||||
|
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
|
||||||
|
|
||||||
|
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||||
|
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||||
|
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||||
|
|
||||||
|
return ThreadStateResponse(
|
||||||
|
values=serialize_channel_values(channel_values),
|
||||||
|
next=next_nodes,
|
||||||
|
tasks=tasks,
|
||||||
|
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||||
|
checkpoint_id=checkpoint_id,
|
||||||
|
parent_checkpoint_id=parent_checkpoint_id,
|
||||||
|
metadata=metadata,
|
||||||
|
created_at=str(metadata.get("created_at", "")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||||
|
async def update_thread_state(
|
||||||
|
thread_id: str,
|
||||||
|
body: ThreadStateUpdateRequest,
|
||||||
|
checkpointer: CurrentCheckpointer,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
) -> ThreadStateResponse:
|
||||||
|
"""Update thread state (human-in-the-loop or title rename)."""
|
||||||
|
read_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
|
if body.checkpoint_id:
|
||||||
|
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||||
|
|
||||||
|
if checkpoint_tuple is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||||
|
|
||||||
|
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
|
||||||
|
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
|
||||||
|
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
|
||||||
|
|
||||||
|
if body.values:
|
||||||
|
channel_values.update(body.values)
|
||||||
|
|
||||||
|
checkpoint["channel_values"] = channel_values
|
||||||
|
metadata["updated_at"] = time.time()
|
||||||
|
|
||||||
|
if body.as_node:
|
||||||
|
metadata["source"] = "update"
|
||||||
|
metadata["step"] = metadata.get("step", 0) + 1
|
||||||
|
metadata["writes"] = {body.as_node: body.values}
|
||||||
|
|
||||||
|
write_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
|
try:
|
||||||
|
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update thread state")
|
||||||
|
|
||||||
|
new_checkpoint_id: str | None = None
|
||||||
|
if isinstance(new_config, dict):
|
||||||
|
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
|
||||||
|
|
||||||
|
# Sync title to thread_meta
|
||||||
|
if body.values and "title" in body.values:
|
||||||
|
new_title = body.values["title"]
|
||||||
|
if new_title:
|
||||||
|
try:
|
||||||
|
await thread_meta_storage.sync_thread_title(
|
||||||
|
thread_id=thread_id,
|
||||||
|
title=new_title,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to sync title for %s", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
|
return ThreadStateResponse(
|
||||||
|
values=serialize_channel_values(channel_values),
|
||||||
|
next=[],
|
||||||
|
metadata=metadata,
|
||||||
|
checkpoint_id=new_checkpoint_id,
|
||||||
|
created_at=str(metadata.get("created_at", "")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
|
||||||
|
async def get_thread_history(
|
||||||
|
thread_id: str,
|
||||||
|
body: ThreadHistoryRequest,
|
||||||
|
request: Request,
|
||||||
|
checkpointer: CurrentCheckpointer,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
run_repo: CurrentRunRepository,
|
||||||
|
) -> list[HistoryEntry]:
|
||||||
|
"""Get checkpoint history for a thread."""
|
||||||
|
config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
|
if body.before:
|
||||||
|
config["configurable"]["checkpoint_id"] = body.before
|
||||||
|
|
||||||
|
entries: list[HistoryEntry] = []
|
||||||
|
is_first = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
|
||||||
|
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
|
||||||
|
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||||
|
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||||
|
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||||
|
|
||||||
|
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
|
||||||
|
parent_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
|
||||||
|
channel_values = checkpoint.get("channel_values", {})
|
||||||
|
|
||||||
|
values: dict[str, Any] = {}
|
||||||
|
if title := channel_values.get("title"):
|
||||||
|
values["title"] = title
|
||||||
|
if is_first and (messages := channel_values.get("messages")):
|
||||||
|
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||||
|
is_first = False
|
||||||
|
|
||||||
|
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||||
|
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||||
|
|
||||||
|
entries.append(
|
||||||
|
HistoryEntry(
|
||||||
|
checkpoint_id=checkpoint_id,
|
||||||
|
parent_checkpoint_id=parent_id,
|
||||||
|
metadata=metadata,
|
||||||
|
values=values,
|
||||||
|
created_at=str(metadata.get("created_at", "")),
|
||||||
|
next=next_nodes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to get thread history")
|
||||||
|
|
||||||
|
if not entries and await _thread_or_run_exists(
|
||||||
|
request=request,
|
||||||
|
thread_id=thread_id,
|
||||||
|
thread_meta_storage=thread_meta_storage,
|
||||||
|
run_repo=run_repo,
|
||||||
|
):
|
||||||
|
return []
|
||||||
|
|
||||||
|
return entries
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
"""Memory API router for retrieving and managing global memory data."""
|
"""Memory API router for retrieving and managing global memory data."""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||||
from deerflow.agents.memory.updater import (
|
from deerflow.agents.memory.updater import (
|
||||||
clear_memory_data,
|
clear_memory_data,
|
||||||
create_memory_fact,
|
create_memory_fact,
|
||||||
@@ -13,7 +14,7 @@ from deerflow.agents.memory.updater import (
|
|||||||
update_memory_fact,
|
update_memory_fact,
|
||||||
)
|
)
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["memory"])
|
router = APIRouter(prefix="/api", tags=["memory"])
|
||||||
|
|
||||||
@@ -114,7 +115,7 @@ class MemoryStatusResponse(BaseModel):
|
|||||||
summary="Get Memory Data",
|
summary="Get Memory Data",
|
||||||
description="Retrieve the current global memory data including user context, history, and facts.",
|
description="Retrieve the current global memory data including user context, history, and facts.",
|
||||||
)
|
)
|
||||||
async def get_memory() -> MemoryResponse:
|
async def get_memory(request: Request) -> MemoryResponse:
|
||||||
"""Get the current global memory data.
|
"""Get the current global memory data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -148,8 +149,9 @@ async def get_memory() -> MemoryResponse:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
with bind_request_actor_context(request):
|
||||||
return MemoryResponse(**memory_data)
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -159,7 +161,7 @@ async def get_memory() -> MemoryResponse:
|
|||||||
summary="Reload Memory Data",
|
summary="Reload Memory Data",
|
||||||
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
||||||
)
|
)
|
||||||
async def reload_memory() -> MemoryResponse:
|
async def reload_memory(request: Request) -> MemoryResponse:
|
||||||
"""Reload memory data from file.
|
"""Reload memory data from file.
|
||||||
|
|
||||||
This forces a reload of the memory data from the storage file,
|
This forces a reload of the memory data from the storage file,
|
||||||
@@ -168,8 +170,9 @@ async def reload_memory() -> MemoryResponse:
|
|||||||
Returns:
|
Returns:
|
||||||
The reloaded memory data.
|
The reloaded memory data.
|
||||||
"""
|
"""
|
||||||
memory_data = reload_memory_data(user_id=get_effective_user_id())
|
with bind_request_actor_context(request):
|
||||||
return MemoryResponse(**memory_data)
|
memory_data = reload_memory_data(user_id=get_effective_user_id())
|
||||||
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
@@ -179,14 +182,15 @@ async def reload_memory() -> MemoryResponse:
|
|||||||
summary="Clear All Memory Data",
|
summary="Clear All Memory Data",
|
||||||
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
||||||
)
|
)
|
||||||
async def clear_memory() -> MemoryResponse:
|
async def clear_memory(request: Request) -> MemoryResponse:
|
||||||
"""Clear all persisted memory data."""
|
"""Clear all persisted memory data."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = clear_memory_data(user_id=get_effective_user_id())
|
try:
|
||||||
except OSError as exc:
|
memory_data = clear_memory_data(user_id=get_effective_user_id())
|
||||||
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -196,21 +200,22 @@ async def clear_memory() -> MemoryResponse:
|
|||||||
summary="Create Memory Fact",
|
summary="Create Memory Fact",
|
||||||
description="Create a single saved memory fact manually.",
|
description="Create a single saved memory fact manually.",
|
||||||
)
|
)
|
||||||
async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse:
|
async def create_memory_fact_endpoint(request: Request, payload: FactCreateRequest) -> MemoryResponse:
|
||||||
"""Create a single fact manually."""
|
"""Create a single fact manually."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = create_memory_fact(
|
try:
|
||||||
content=request.content,
|
memory_data = create_memory_fact(
|
||||||
category=request.category,
|
content=payload.content,
|
||||||
confidence=request.confidence,
|
category=payload.category,
|
||||||
user_id=get_effective_user_id(),
|
confidence=payload.confidence,
|
||||||
)
|
user_id=get_effective_user_id(),
|
||||||
except ValueError as exc:
|
)
|
||||||
raise _map_memory_fact_value_error(exc) from exc
|
except ValueError as exc:
|
||||||
except OSError as exc:
|
raise _map_memory_fact_value_error(exc) from exc
|
||||||
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
@@ -220,16 +225,17 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
|||||||
summary="Delete Memory Fact",
|
summary="Delete Memory Fact",
|
||||||
description="Delete a single saved memory fact by its fact id.",
|
description="Delete a single saved memory fact by its fact id.",
|
||||||
)
|
)
|
||||||
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
async def delete_memory_fact_endpoint(fact_id: str, request: Request) -> MemoryResponse:
|
||||||
"""Delete a single fact from memory by fact id."""
|
"""Delete a single fact from memory by fact id."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
|
try:
|
||||||
except KeyError as exc:
|
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
|
||||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
except KeyError as exc:
|
||||||
except OSError as exc:
|
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
@@ -239,24 +245,25 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
|||||||
summary="Patch Memory Fact",
|
summary="Patch Memory Fact",
|
||||||
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
||||||
)
|
)
|
||||||
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse:
|
async def update_memory_fact_endpoint(fact_id: str, request: Request, payload: FactPatchRequest) -> MemoryResponse:
|
||||||
"""Partially update a single fact manually."""
|
"""Partially update a single fact manually."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = update_memory_fact(
|
try:
|
||||||
fact_id=fact_id,
|
memory_data = update_memory_fact(
|
||||||
content=request.content,
|
fact_id=fact_id,
|
||||||
category=request.category,
|
content=payload.content,
|
||||||
confidence=request.confidence,
|
category=payload.category,
|
||||||
user_id=get_effective_user_id(),
|
confidence=payload.confidence,
|
||||||
)
|
user_id=get_effective_user_id(),
|
||||||
except ValueError as exc:
|
)
|
||||||
raise _map_memory_fact_value_error(exc) from exc
|
except ValueError as exc:
|
||||||
except KeyError as exc:
|
raise _map_memory_fact_value_error(exc) from exc
|
||||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
except KeyError as exc:
|
||||||
except OSError as exc:
|
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||||
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -266,10 +273,11 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
|||||||
summary="Export Memory Data",
|
summary="Export Memory Data",
|
||||||
description="Export the current global memory data as JSON for backup or transfer.",
|
description="Export the current global memory data as JSON for backup or transfer.",
|
||||||
)
|
)
|
||||||
async def export_memory() -> MemoryResponse:
|
async def export_memory(request: Request) -> MemoryResponse:
|
||||||
"""Export the current memory data."""
|
"""Export the current memory data."""
|
||||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
with bind_request_actor_context(request):
|
||||||
return MemoryResponse(**memory_data)
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -279,14 +287,15 @@ async def export_memory() -> MemoryResponse:
|
|||||||
summary="Import Memory Data",
|
summary="Import Memory Data",
|
||||||
description="Import and overwrite the current global memory data from a JSON payload.",
|
description="Import and overwrite the current global memory data from a JSON payload.",
|
||||||
)
|
)
|
||||||
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
async def import_memory(request: Request, payload: MemoryResponse) -> MemoryResponse:
|
||||||
"""Import and persist memory data."""
|
"""Import and persist memory data."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id())
|
try:
|
||||||
except OSError as exc:
|
memory_data = import_memory_data(payload.model_dump(), user_id=get_effective_user_id())
|
||||||
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -333,24 +342,25 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
|||||||
summary="Get Memory Status",
|
summary="Get Memory Status",
|
||||||
description="Retrieve both memory configuration and current data in a single request.",
|
description="Retrieve both memory configuration and current data in a single request.",
|
||||||
)
|
)
|
||||||
async def get_memory_status() -> MemoryStatusResponse:
|
async def get_memory_status(request: Request) -> MemoryStatusResponse:
|
||||||
"""Get the memory system status including configuration and data.
|
"""Get the memory system status including configuration and data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Combined memory configuration and current data.
|
Combined memory configuration and current data.
|
||||||
"""
|
"""
|
||||||
config = get_memory_config()
|
with bind_request_actor_context(request):
|
||||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
config = get_memory_config()
|
||||||
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
|
|
||||||
return MemoryStatusResponse(
|
return MemoryStatusResponse(
|
||||||
config=MemoryConfigResponse(
|
config=MemoryConfigResponse(
|
||||||
enabled=config.enabled,
|
enabled=config.enabled,
|
||||||
storage_path=config.storage_path,
|
storage_path=config.storage_path,
|
||||||
debounce_seconds=config.debounce_seconds,
|
debounce_seconds=config.debounce_seconds,
|
||||||
max_facts=config.max_facts,
|
max_facts=config.max_facts,
|
||||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||||
injection_enabled=config.injection_enabled,
|
injection_enabled=config.injection_enabled,
|
||||||
max_injection_tokens=config.max_injection_tokens,
|
max_injection_tokens=config.max_injection_tokens,
|
||||||
),
|
),
|
||||||
data=MemoryResponse(**memory_data),
|
data=MemoryResponse(**memory_data),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,143 +0,0 @@
|
|||||||
"""Stateless runs endpoints -- stream and wait without a pre-existing thread.
|
|
||||||
|
|
||||||
These endpoints auto-create a temporary thread when no ``thread_id`` is
|
|
||||||
supplied in the request body. When a ``thread_id`` **is** provided, it
|
|
||||||
is reused so that conversation history is preserved across calls.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
|
||||||
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
|
||||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
|
||||||
from app.gateway.services import sse_consumer, start_run
|
|
||||||
from deerflow.runtime import serialize_channel_values
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/runs", tags=["runs"])
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_thread_id(body: RunCreateRequest) -> str:
|
|
||||||
"""Return the thread_id from the request body, or generate a new one."""
|
|
||||||
thread_id = (body.config or {}).get("configurable", {}).get("thread_id")
|
|
||||||
if thread_id:
|
|
||||||
return str(thread_id)
|
|
||||||
return str(uuid.uuid4())
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/stream")
|
|
||||||
async def stateless_stream(body: RunCreateRequest, request: Request) -> StreamingResponse:
|
|
||||||
"""Create a run and stream events via SSE.
|
|
||||||
|
|
||||||
If ``config.configurable.thread_id`` is provided, the run is created
|
|
||||||
on the given thread so that conversation history is preserved.
|
|
||||||
Otherwise a new temporary thread is created.
|
|
||||||
"""
|
|
||||||
thread_id = _resolve_thread_id(body)
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
sse_consumer(bridge, record, request, run_mgr),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/wait", response_model=dict)
|
|
||||||
async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
|
||||||
"""Create a run and block until completion.
|
|
||||||
|
|
||||||
If ``config.configurable.thread_id`` is provided, the run is created
|
|
||||||
on the given thread so that conversation history is preserved.
|
|
||||||
Otherwise a new temporary thread is created.
|
|
||||||
"""
|
|
||||||
thread_id = _resolve_thread_id(body)
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
if record.task is not None:
|
|
||||||
try:
|
|
||||||
await record.task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
config = {"configurable": {"thread_id": thread_id}}
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
|
||||||
if checkpoint_tuple is not None:
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
return serialize_channel_values(channel_values)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
|
||||||
|
|
||||||
return {"status": record.status.value, "error": record.error}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Run-scoped read endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def _resolve_run(run_id: str, request: Request) -> dict:
|
|
||||||
"""Fetch run by run_id with user ownership check. Raises 404 if not found."""
|
|
||||||
run_store = get_run_store(request)
|
|
||||||
record = await run_store.get(run_id) # user_id=AUTO filters by contextvar
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{run_id}/messages")
|
|
||||||
@require_permission("runs", "read")
|
|
||||||
async def run_messages(
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
limit: int = Query(default=50, le=200, ge=1),
|
|
||||||
before_seq: int | None = Query(default=None),
|
|
||||||
after_seq: int | None = Query(default=None),
|
|
||||||
) -> dict:
|
|
||||||
"""Return paginated messages for a run (cursor-based).
|
|
||||||
|
|
||||||
Pagination:
|
|
||||||
- after_seq: messages with seq > after_seq (forward)
|
|
||||||
- before_seq: messages with seq < before_seq (backward)
|
|
||||||
- neither: latest messages
|
|
||||||
|
|
||||||
Response: { data: [...], has_more: bool }
|
|
||||||
"""
|
|
||||||
run = await _resolve_run(run_id, request)
|
|
||||||
event_store = get_run_event_store(request)
|
|
||||||
rows = await event_store.list_messages_by_run(
|
|
||||||
run["thread_id"],
|
|
||||||
run_id,
|
|
||||||
limit=limit + 1,
|
|
||||||
before_seq=before_seq,
|
|
||||||
after_seq=after_seq,
|
|
||||||
)
|
|
||||||
has_more = len(rows) > limit
|
|
||||||
data = rows[:limit] if has_more else rows
|
|
||||||
return {"data": data, "has_more": has_more}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{run_id}/feedback")
|
|
||||||
@require_permission("runs", "read")
|
|
||||||
async def run_feedback(run_id: str, request: Request) -> list[dict]:
|
|
||||||
"""Return all feedback for a run."""
|
|
||||||
run = await _resolve_run(run_id, request)
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
return await feedback_repo.list_by_run(run["thread_id"], run_id)
|
|
||||||
@@ -5,7 +5,6 @@ from fastapi import APIRouter, Request
|
|||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -99,7 +98,6 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
|
|||||||
summary="Generate Follow-up Questions",
|
summary="Generate Follow-up Questions",
|
||||||
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
||||||
)
|
)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
|
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
|
||||||
if not body.messages:
|
if not body.messages:
|
||||||
return SuggestionsResponse(suggestions=[])
|
return SuggestionsResponse(suggestions=[])
|
||||||
|
|||||||
@@ -1,377 +0,0 @@
|
|||||||
"""Runs endpoints — create, stream, wait, cancel.
|
|
||||||
|
|
||||||
Implements the LangGraph Platform runs API on top of
|
|
||||||
:class:`deerflow.agents.runs.RunManager` and
|
|
||||||
:class:`deerflow.agents.stream_bridge.StreamBridge`.
|
|
||||||
|
|
||||||
SSE format is aligned with the LangGraph Platform protocol so that
|
|
||||||
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
|
|
||||||
works without modification.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
|
||||||
from fastapi.responses import Response, StreamingResponse
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
|
||||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
|
||||||
from app.gateway.services import sse_consumer, start_run
|
|
||||||
from deerflow.runtime import RunRecord, serialize_channel_values
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Request / response models
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class RunCreateRequest(BaseModel):
|
|
||||||
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
|
|
||||||
input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
|
|
||||||
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
|
|
||||||
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
|
|
||||||
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
|
|
||||||
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
|
||||||
webhook: str | None = Field(default=None, description="Completion callback URL")
|
|
||||||
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
|
||||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
|
||||||
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
|
|
||||||
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
|
|
||||||
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
|
|
||||||
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
|
|
||||||
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
|
|
||||||
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
|
|
||||||
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
|
|
||||||
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
|
|
||||||
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
|
||||||
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
|
||||||
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
|
||||||
|
|
||||||
|
|
||||||
class RunResponse(BaseModel):
|
|
||||||
run_id: str
|
|
||||||
thread_id: str
|
|
||||||
assistant_id: str | None = None
|
|
||||||
status: str
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
kwargs: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
multitask_strategy: str = "reject"
|
|
||||||
created_at: str = ""
|
|
||||||
updated_at: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _record_to_response(record: RunRecord) -> RunResponse:
|
|
||||||
return RunResponse(
|
|
||||||
run_id=record.run_id,
|
|
||||||
thread_id=record.thread_id,
|
|
||||||
assistant_id=record.assistant_id,
|
|
||||||
status=record.status.value,
|
|
||||||
metadata=record.metadata,
|
|
||||||
kwargs=record.kwargs,
|
|
||||||
multitask_strategy=record.multitask_strategy,
|
|
||||||
created_at=record.created_at,
|
|
||||||
updated_at=record.updated_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs", response_model=RunResponse)
|
|
||||||
@require_permission("runs", "create", owner_check=True, require_existing=True)
|
|
||||||
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
|
|
||||||
"""Create a background run (returns immediately)."""
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
return _record_to_response(record)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/stream")
|
|
||||||
@require_permission("runs", "create", owner_check=True, require_existing=True)
|
|
||||||
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
|
|
||||||
"""Create a run and stream events via SSE.
|
|
||||||
|
|
||||||
The response includes a ``Content-Location`` header with the run's
|
|
||||||
resource URL, matching the LangGraph Platform protocol. The
|
|
||||||
``useStream`` React hook uses this to extract run metadata.
|
|
||||||
"""
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
sse_consumer(bridge, record, request, run_mgr),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
# LangGraph Platform includes run metadata in this header.
|
|
||||||
# The SDK uses a greedy regex to extract the run id from this path,
|
|
||||||
# so it must point at the canonical run resource without extra suffixes.
|
|
||||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/wait", response_model=dict)
|
|
||||||
@require_permission("runs", "create", owner_check=True, require_existing=True)
|
|
||||||
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
|
|
||||||
"""Create a run and block until it completes, returning the final state."""
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
if record.task is not None:
|
|
||||||
try:
|
|
||||||
await record.task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
config = {"configurable": {"thread_id": thread_id}}
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
|
||||||
if checkpoint_tuple is not None:
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
return serialize_channel_values(channel_values)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
|
||||||
|
|
||||||
return {"status": record.status.value, "error": record.error}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
|
|
||||||
@require_permission("runs", "read", owner_check=True)
|
|
||||||
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
|
|
||||||
"""List all runs for a thread."""
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
records = await run_mgr.list_by_thread(thread_id)
|
|
||||||
return [_record_to_response(r) for r in records]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
|
|
||||||
@require_permission("runs", "read", owner_check=True)
|
|
||||||
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
|
||||||
"""Get details of a specific run."""
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = run_mgr.get(run_id)
|
|
||||||
if record is None or record.thread_id != thread_id:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
return _record_to_response(record)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/{run_id}/cancel")
|
|
||||||
@require_permission("runs", "cancel", owner_check=True, require_existing=True)
|
|
||||||
async def cancel_run(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
wait: bool = Query(default=False, description="Block until run completes after cancel"),
|
|
||||||
action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
|
|
||||||
) -> Response:
|
|
||||||
"""Cancel a running or pending run.
|
|
||||||
|
|
||||||
- action=interrupt: Stop execution, keep current checkpoint (can be resumed)
|
|
||||||
- action=rollback: Stop execution, revert to pre-run checkpoint state
|
|
||||||
- wait=true: Block until the run fully stops, return 204
|
|
||||||
- wait=false: Return immediately with 202
|
|
||||||
"""
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = run_mgr.get(run_id)
|
|
||||||
if record is None or record.thread_id != thread_id:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
|
|
||||||
cancelled = await run_mgr.cancel(run_id, action=action)
|
|
||||||
if not cancelled:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=409,
|
|
||||||
detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
|
|
||||||
)
|
|
||||||
|
|
||||||
if wait and record.task is not None:
|
|
||||||
try:
|
|
||||||
await record.task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
return Response(status_code=204)
|
|
||||||
|
|
||||||
return Response(status_code=202)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/join")
|
|
||||||
@require_permission("runs", "read", owner_check=True)
|
|
||||||
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
|
|
||||||
"""Join an existing run's SSE stream."""
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = run_mgr.get(run_id)
|
|
||||||
if record is None or record.thread_id != thread_id:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
sse_consumer(bridge, record, request, run_mgr),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
|
|
||||||
@require_permission("runs", "read", owner_check=True)
|
|
||||||
async def stream_existing_run(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
|
|
||||||
wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
|
|
||||||
):
|
|
||||||
"""Join an existing run's SSE stream (GET), or cancel-then-stream (POST).
|
|
||||||
|
|
||||||
The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use
|
|
||||||
``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback``
|
|
||||||
is present the run is cancelled first; the response then streams any
|
|
||||||
remaining buffered events so the client observes a clean shutdown.
|
|
||||||
"""
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = run_mgr.get(run_id)
|
|
||||||
if record is None or record.thread_id != thread_id:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
|
|
||||||
# Cancel if an action was requested (stop-button / interrupt flow)
|
|
||||||
if action is not None:
|
|
||||||
cancelled = await run_mgr.cancel(run_id, action=action)
|
|
||||||
if cancelled and wait and record.task is not None:
|
|
||||||
try:
|
|
||||||
await record.task
|
|
||||||
except (asyncio.CancelledError, Exception):
|
|
||||||
pass
|
|
||||||
return Response(status_code=204)
|
|
||||||
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
return StreamingResponse(
|
|
||||||
sse_consumer(bridge, record, request, run_mgr),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Messages / Events / Token usage endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/messages")
|
|
||||||
@require_permission("runs", "read", owner_check=True)
|
|
||||||
async def list_thread_messages(
|
|
||||||
thread_id: str,
|
|
||||||
request: Request,
|
|
||||||
limit: int = Query(default=50, le=200),
|
|
||||||
before_seq: int | None = Query(default=None),
|
|
||||||
after_seq: int | None = Query(default=None),
|
|
||||||
) -> list[dict]:
|
|
||||||
"""Return displayable messages for a thread (across all runs), with feedback attached."""
|
|
||||||
event_store = get_run_event_store(request)
|
|
||||||
messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
|
|
||||||
|
|
||||||
# Attach feedback to the last AI message of each run
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
user_id = await get_current_user(request)
|
|
||||||
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
|
|
||||||
|
|
||||||
# Find the last ai_message per run_id
|
|
||||||
last_ai_per_run: dict[str, int] = {} # run_id -> index in messages list
|
|
||||||
for i, msg in enumerate(messages):
|
|
||||||
if msg.get("event_type") == "ai_message":
|
|
||||||
last_ai_per_run[msg["run_id"]] = i
|
|
||||||
|
|
||||||
# Attach feedback field
|
|
||||||
last_ai_indices = set(last_ai_per_run.values())
|
|
||||||
for i, msg in enumerate(messages):
|
|
||||||
if i in last_ai_indices:
|
|
||||||
run_id = msg["run_id"]
|
|
||||||
fb = feedback_map.get(run_id)
|
|
||||||
msg["feedback"] = (
|
|
||||||
{
|
|
||||||
"feedback_id": fb["feedback_id"],
|
|
||||||
"rating": fb["rating"],
|
|
||||||
"comment": fb.get("comment"),
|
|
||||||
}
|
|
||||||
if fb
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
msg["feedback"] = None
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/messages")
|
|
||||||
@require_permission("runs", "read", owner_check=True)
|
|
||||||
async def list_run_messages(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
limit: int = Query(default=50, le=200, ge=1),
|
|
||||||
before_seq: int | None = Query(default=None),
|
|
||||||
after_seq: int | None = Query(default=None),
|
|
||||||
) -> dict:
|
|
||||||
"""Return paginated messages for a specific run.
|
|
||||||
|
|
||||||
Response: { data: [...], has_more: bool }
|
|
||||||
"""
|
|
||||||
event_store = get_run_event_store(request)
|
|
||||||
rows = await event_store.list_messages_by_run(
|
|
||||||
thread_id,
|
|
||||||
run_id,
|
|
||||||
limit=limit + 1,
|
|
||||||
before_seq=before_seq,
|
|
||||||
after_seq=after_seq,
|
|
||||||
)
|
|
||||||
has_more = len(rows) > limit
|
|
||||||
data = rows[:limit] if has_more else rows
|
|
||||||
return {"data": data, "has_more": has_more}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/events")
|
|
||||||
@require_permission("runs", "read", owner_check=True)
|
|
||||||
async def list_run_events(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
event_types: str | None = Query(default=None),
|
|
||||||
limit: int = Query(default=500, le=2000),
|
|
||||||
) -> list[dict]:
|
|
||||||
"""Return the full event stream for a run (debug/audit)."""
|
|
||||||
event_store = get_run_event_store(request)
|
|
||||||
types = event_types.split(",") if event_types else None
|
|
||||||
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/token-usage")
|
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def thread_token_usage(thread_id: str, request: Request) -> dict:
|
|
||||||
"""Thread-level token usage aggregation."""
|
|
||||||
run_store = get_run_store(request)
|
|
||||||
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
|
||||||
return {"thread_id": thread_id, **agg}
|
|
||||||
@@ -1,621 +0,0 @@
|
|||||||
"""Thread CRUD, state, and history endpoints.
|
|
||||||
|
|
||||||
Combines the existing thread-local filesystem cleanup with LangGraph
|
|
||||||
Platform-compatible thread management backed by the checkpointer.
|
|
||||||
|
|
||||||
Channel values returned in state responses are serialized through
|
|
||||||
:func:`deerflow.runtime.serialization.serialize_channel_values` to
|
|
||||||
ensure LangChain message objects are converted to JSON-safe dicts
|
|
||||||
matching the LangGraph Platform wire format expected by the
|
|
||||||
``useStream`` React hook.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
|
||||||
from app.gateway.deps import get_checkpointer
|
|
||||||
from app.gateway.utils import sanitize_log_param
|
|
||||||
from deerflow.config.paths import Paths, get_paths
|
|
||||||
from deerflow.runtime import serialize_channel_values
|
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
|
||||||
|
|
||||||
|
|
||||||
# Metadata keys that the server controls; clients are not allowed to set
|
|
||||||
# them. Pydantic ``@field_validator("metadata")`` strips them on every
|
|
||||||
# inbound model below so a malicious client cannot reflect a forged
|
|
||||||
# owner identity through the API surface. Defense-in-depth — the
|
|
||||||
# row-level invariant is still ``threads_meta.user_id`` populated from
|
|
||||||
# the auth contextvar; this list closes the metadata-blob echo gap.
|
|
||||||
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_reserved_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]:
|
|
||||||
"""Return ``metadata`` with server-controlled keys removed."""
|
|
||||||
if not metadata:
|
|
||||||
return metadata or {}
|
|
||||||
return {k: v for k, v in metadata.items() if k not in _SERVER_RESERVED_METADATA_KEYS}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Response / request models
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadDeleteResponse(BaseModel):
|
|
||||||
"""Response model for thread cleanup."""
|
|
||||||
|
|
||||||
success: bool
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadResponse(BaseModel):
|
|
||||||
"""Response model for a single thread."""
|
|
||||||
|
|
||||||
thread_id: str = Field(description="Unique thread identifier")
|
|
||||||
status: str = Field(default="idle", description="Thread status: idle, busy, interrupted, error")
|
|
||||||
created_at: str = Field(default="", description="ISO timestamp")
|
|
||||||
updated_at: str = Field(default="", description="ISO timestamp")
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
|
|
||||||
values: dict[str, Any] = Field(default_factory=dict, description="Current state channel values")
|
|
||||||
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadCreateRequest(BaseModel):
|
|
||||||
"""Request body for creating a thread."""
|
|
||||||
|
|
||||||
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
|
||||||
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
|
||||||
|
|
||||||
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadSearchRequest(BaseModel):
|
|
||||||
"""Request body for searching threads."""
|
|
||||||
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
|
|
||||||
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
|
|
||||||
offset: int = Field(default=0, ge=0, description="Pagination offset")
|
|
||||||
status: str | None = Field(default=None, description="Filter by thread status")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadStateResponse(BaseModel):
|
|
||||||
"""Response model for thread state."""
|
|
||||||
|
|
||||||
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
|
|
||||||
next: list[str] = Field(default_factory=list, description="Next tasks to execute")
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
|
|
||||||
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
|
|
||||||
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
|
|
||||||
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
|
|
||||||
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
|
|
||||||
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadPatchRequest(BaseModel):
|
|
||||||
"""Request body for patching thread metadata."""
|
|
||||||
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")
|
|
||||||
|
|
||||||
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadStateUpdateRequest(BaseModel):
|
|
||||||
"""Request body for updating thread state (human-in-the-loop resume)."""
|
|
||||||
|
|
||||||
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
|
|
||||||
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
|
|
||||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
|
||||||
as_node: str | None = Field(default=None, description="Node identity for the update")
|
|
||||||
|
|
||||||
|
|
||||||
class HistoryEntry(BaseModel):
|
|
||||||
"""Single checkpoint history entry."""
|
|
||||||
|
|
||||||
checkpoint_id: str
|
|
||||||
parent_checkpoint_id: str | None = None
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
values: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
created_at: str | None = None
|
|
||||||
next: list[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadHistoryRequest(BaseModel):
|
|
||||||
"""Request body for checkpoint history."""
|
|
||||||
|
|
||||||
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
|
|
||||||
before: str | None = Field(default=None, description="Cursor for pagination")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse:
|
|
||||||
"""Delete local persisted filesystem data for a thread."""
|
|
||||||
path_manager = paths or get_paths()
|
|
||||||
try:
|
|
||||||
path_manager.delete_thread_dir(thread_id, user_id=user_id)
|
|
||||||
except ValueError as exc:
|
|
||||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
|
||||||
except FileNotFoundError:
|
|
||||||
# Not critical — thread data may not exist on disk
|
|
||||||
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
|
|
||||||
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
|
|
||||||
except Exception as exc:
|
|
||||||
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
|
|
||||||
|
|
||||||
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
|
|
||||||
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
|
|
||||||
|
|
||||||
|
|
||||||
def _derive_thread_status(checkpoint_tuple) -> str:
|
|
||||||
"""Derive thread status from checkpoint metadata."""
|
|
||||||
if checkpoint_tuple is None:
|
|
||||||
return "idle"
|
|
||||||
pending_writes = getattr(checkpoint_tuple, "pending_writes", None) or []
|
|
||||||
|
|
||||||
# Check for error in pending writes
|
|
||||||
for pw in pending_writes:
|
|
||||||
if len(pw) >= 2 and pw[1] == "__error__":
|
|
||||||
return "error"
|
|
||||||
|
|
||||||
# Check for pending next tasks (indicates interrupt)
|
|
||||||
tasks = getattr(checkpoint_tuple, "tasks", None)
|
|
||||||
if tasks:
|
|
||||||
return "interrupted"
|
|
||||||
|
|
||||||
return "idle"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
|
|
||||||
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
|
||||||
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
|
|
||||||
"""Delete local persisted filesystem data for a thread.
|
|
||||||
|
|
||||||
Cleans DeerFlow-managed thread directories, removes checkpoint data,
|
|
||||||
and removes the thread_meta row from the configured ThreadMetaStore
|
|
||||||
(sqlite or memory).
|
|
||||||
"""
|
|
||||||
from app.gateway.deps import get_thread_store
|
|
||||||
|
|
||||||
# Clean local filesystem
|
|
||||||
response = _delete_thread_data(thread_id, user_id=get_effective_user_id())
|
|
||||||
|
|
||||||
# Remove checkpoints (best-effort)
|
|
||||||
checkpointer = getattr(request.app.state, "checkpointer", None)
|
|
||||||
if checkpointer is not None:
|
|
||||||
try:
|
|
||||||
if hasattr(checkpointer, "adelete_thread"):
|
|
||||||
await checkpointer.adelete_thread(thread_id)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not delete checkpoints for thread %s (not critical)", sanitize_log_param(thread_id))
|
|
||||||
|
|
||||||
# Remove thread_meta row (best-effort) — required for sqlite backend
|
|
||||||
# so the deleted thread no longer appears in /threads/search.
|
|
||||||
try:
|
|
||||||
thread_store = get_thread_store(request)
|
|
||||||
await thread_store.delete(thread_id)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ThreadResponse)
|
|
||||||
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
|
|
||||||
"""Create a new thread.
|
|
||||||
|
|
||||||
Writes a thread_meta record (so the thread appears in /threads/search)
|
|
||||||
and an empty checkpoint (so state endpoints work immediately).
|
|
||||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
|
||||||
"""
|
|
||||||
from app.gateway.deps import get_thread_store
|
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
thread_store = get_thread_store(request)
|
|
||||||
thread_id = body.thread_id or str(uuid.uuid4())
|
|
||||||
now = time.time()
|
|
||||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
|
||||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
|
||||||
|
|
||||||
# Idempotency: return existing record when already present
|
|
||||||
existing_record = await thread_store.get(thread_id)
|
|
||||||
if existing_record is not None:
|
|
||||||
return ThreadResponse(
|
|
||||||
thread_id=thread_id,
|
|
||||||
status=existing_record.get("status", "idle"),
|
|
||||||
created_at=str(existing_record.get("created_at", "")),
|
|
||||||
updated_at=str(existing_record.get("updated_at", "")),
|
|
||||||
metadata=existing_record.get("metadata", {}),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Write thread_meta so the thread appears in /threads/search immediately
|
|
||||||
try:
|
|
||||||
await thread_store.create(
|
|
||||||
thread_id,
|
|
||||||
assistant_id=getattr(body, "assistant_id", None),
|
|
||||||
metadata=body.metadata,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to write thread_meta for %s", sanitize_log_param(thread_id))
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
|
||||||
|
|
||||||
# Write an empty checkpoint so state endpoints work immediately
|
|
||||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
|
||||||
try:
|
|
||||||
from langgraph.checkpoint.base import empty_checkpoint
|
|
||||||
|
|
||||||
ckpt_metadata = {
|
|
||||||
"step": -1,
|
|
||||||
"source": "input",
|
|
||||||
"writes": None,
|
|
||||||
"parents": {},
|
|
||||||
**body.metadata,
|
|
||||||
"created_at": now,
|
|
||||||
}
|
|
||||||
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id))
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
|
||||||
|
|
||||||
logger.info("Thread created: %s", sanitize_log_param(thread_id))
|
|
||||||
return ThreadResponse(
|
|
||||||
thread_id=thread_id,
|
|
||||||
status="idle",
|
|
||||||
created_at=str(now),
|
|
||||||
updated_at=str(now),
|
|
||||||
metadata=body.metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/search", response_model=list[ThreadResponse])
|
|
||||||
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
|
|
||||||
"""Search and list threads.
|
|
||||||
|
|
||||||
Delegates to the configured ThreadMetaStore implementation
|
|
||||||
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
|
||||||
"""
|
|
||||||
from app.gateway.deps import get_thread_store
|
|
||||||
|
|
||||||
repo = get_thread_store(request)
|
|
||||||
rows = await repo.search(
|
|
||||||
metadata=body.metadata or None,
|
|
||||||
status=body.status,
|
|
||||||
limit=body.limit,
|
|
||||||
offset=body.offset,
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
ThreadResponse(
|
|
||||||
thread_id=r["thread_id"],
|
|
||||||
status=r.get("status", "idle"),
|
|
||||||
created_at=r.get("created_at", ""),
|
|
||||||
updated_at=r.get("updated_at", ""),
|
|
||||||
metadata=r.get("metadata", {}),
|
|
||||||
values={"title": r["display_name"]} if r.get("display_name") else {},
|
|
||||||
interrupts={},
|
|
||||||
)
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{thread_id}", response_model=ThreadResponse)
|
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
|
||||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
|
||||||
"""Merge metadata into a thread record."""
|
|
||||||
from app.gateway.deps import get_thread_store
|
|
||||||
|
|
||||||
thread_store = get_thread_store(request)
|
|
||||||
record = await thread_store.get(thread_id)
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
||||||
|
|
||||||
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
|
|
||||||
try:
|
|
||||||
await thread_store.update_metadata(thread_id, body.metadata)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to update thread")
|
|
||||||
|
|
||||||
# Re-read to get the merged metadata + refreshed updated_at
|
|
||||||
record = await thread_store.get(thread_id) or record
|
|
||||||
return ThreadResponse(
|
|
||||||
thread_id=thread_id,
|
|
||||||
status=record.get("status", "idle"),
|
|
||||||
created_at=str(record.get("created_at", "")),
|
|
||||||
updated_at=str(record.get("updated_at", "")),
|
|
||||||
metadata=record.get("metadata", {}),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}", response_model=ThreadResponse)
|
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
|
||||||
"""Get thread info.
|
|
||||||
|
|
||||||
Reads metadata from the ThreadMetaStore and derives the accurate
|
|
||||||
execution status from the checkpointer. Falls back to the checkpointer
|
|
||||||
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
|
|
||||||
"""
|
|
||||||
from app.gateway.deps import get_thread_store
|
|
||||||
|
|
||||||
thread_store = get_thread_store(request)
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
|
|
||||||
record: dict | None = await thread_store.get(thread_id)
|
|
||||||
|
|
||||||
# Derive accurate status from the checkpointer
|
|
||||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get checkpoint for thread %s", sanitize_log_param(thread_id))
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to get thread")
|
|
||||||
|
|
||||||
if record is None and checkpoint_tuple is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
||||||
|
|
||||||
# If the thread exists in the checkpointer but not in thread_meta (e.g.
|
|
||||||
# legacy data created before thread_meta adoption), synthesize a minimal
|
|
||||||
# record from the checkpoint metadata.
|
|
||||||
if record is None and checkpoint_tuple is not None:
|
|
||||||
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
|
|
||||||
record = {
|
|
||||||
"thread_id": thread_id,
|
|
||||||
"status": "idle",
|
|
||||||
"created_at": ckpt_meta.get("created_at", ""),
|
|
||||||
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
|
|
||||||
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
|
|
||||||
}
|
|
||||||
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
||||||
|
|
||||||
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
|
|
||||||
return ThreadResponse(
|
|
||||||
thread_id=thread_id,
|
|
||||||
status=status,
|
|
||||||
created_at=str(record.get("created_at", "")),
|
|
||||||
updated_at=str(record.get("updated_at", "")),
|
|
||||||
metadata=record.get("metadata", {}),
|
|
||||||
values=serialize_channel_values(channel_values),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
|
||||||
"""Get the latest state snapshot for a thread.
|
|
||||||
|
|
||||||
Channel values are serialized to ensure LangChain message objects
|
|
||||||
are converted to JSON-safe dicts.
|
|
||||||
"""
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
|
|
||||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
|
||||||
|
|
||||||
if checkpoint_tuple is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
||||||
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
|
||||||
checkpoint_id = None
|
|
||||||
ckpt_config = getattr(checkpoint_tuple, "config", {})
|
|
||||||
if ckpt_config:
|
|
||||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
|
|
||||||
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
|
|
||||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
|
||||||
parent_checkpoint_id = None
|
|
||||||
if parent_config:
|
|
||||||
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id")
|
|
||||||
|
|
||||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
|
||||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
|
||||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
|
||||||
|
|
||||||
values = serialize_channel_values(channel_values)
|
|
||||||
|
|
||||||
return ThreadStateResponse(
|
|
||||||
values=values,
|
|
||||||
next=next_tasks,
|
|
||||||
metadata=metadata,
|
|
||||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
|
||||||
checkpoint_id=checkpoint_id,
|
|
||||||
parent_checkpoint_id=parent_checkpoint_id,
|
|
||||||
created_at=str(metadata.get("created_at", "")),
|
|
||||||
tasks=tasks,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
|
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
|
||||||
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
|
|
||||||
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
|
|
||||||
|
|
||||||
Writes a new checkpoint that merges *body.values* into the latest
|
|
||||||
channel values, then syncs any updated ``title`` field through the
|
|
||||||
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
|
|
||||||
change immediately in both sqlite and memory backends.
|
|
||||||
"""
|
|
||||||
from app.gateway.deps import get_thread_store
|
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
thread_store = get_thread_store(request)
|
|
||||||
|
|
||||||
# checkpoint_ns must be present in the config for aput — default to ""
|
|
||||||
# (the root graph namespace). checkpoint_id is optional; omitting it
|
|
||||||
# fetches the latest checkpoint for the thread.
|
|
||||||
read_config: dict[str, Any] = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": thread_id,
|
|
||||||
"checkpoint_ns": "",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if body.checkpoint_id:
|
|
||||||
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
|
||||||
|
|
||||||
if checkpoint_tuple is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
||||||
|
|
||||||
# Work on mutable copies so we don't accidentally mutate cached objects.
|
|
||||||
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
|
|
||||||
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
|
|
||||||
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
|
|
||||||
|
|
||||||
if body.values:
|
|
||||||
channel_values.update(body.values)
|
|
||||||
|
|
||||||
checkpoint["channel_values"] = channel_values
|
|
||||||
metadata["updated_at"] = time.time()
|
|
||||||
|
|
||||||
if body.as_node:
|
|
||||||
metadata["source"] = "update"
|
|
||||||
metadata["step"] = metadata.get("step", 0) + 1
|
|
||||||
metadata["writes"] = {body.as_node: body.values}
|
|
||||||
|
|
||||||
# aput requires checkpoint_ns in the config — use the same config used for the
|
|
||||||
# read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id
|
|
||||||
# so that aput generates a fresh checkpoint ID for the new snapshot.
|
|
||||||
write_config: dict[str, Any] = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": thread_id,
|
|
||||||
"checkpoint_ns": "",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to update thread state")
|
|
||||||
|
|
||||||
new_checkpoint_id: str | None = None
|
|
||||||
if isinstance(new_config, dict):
|
|
||||||
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
|
|
||||||
|
|
||||||
# Sync title changes through the ThreadMetaStore abstraction so /threads/search
|
|
||||||
# reflects them immediately in both sqlite and memory backends.
|
|
||||||
if body.values and "title" in body.values:
|
|
||||||
new_title = body.values["title"]
|
|
||||||
if new_title: # Skip empty strings and None
|
|
||||||
try:
|
|
||||||
await thread_store.update_display_name(thread_id, new_title)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
|
||||||
|
|
||||||
return ThreadStateResponse(
|
|
||||||
values=serialize_channel_values(channel_values),
|
|
||||||
next=[],
|
|
||||||
metadata=metadata,
|
|
||||||
checkpoint_id=new_checkpoint_id,
|
|
||||||
created_at=str(metadata.get("created_at", "")),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
|
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
|
|
||||||
"""Get checkpoint history for a thread.
|
|
||||||
|
|
||||||
Messages are read from the checkpointer's channel values (the
|
|
||||||
authoritative source) and serialized via
|
|
||||||
:func:`~deerflow.runtime.serialization.serialize_channel_values`.
|
|
||||||
Only the latest (first) checkpoint carries the ``messages`` key to
|
|
||||||
avoid duplicating them across every entry.
|
|
||||||
"""
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
|
|
||||||
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
|
|
||||||
if body.before:
|
|
||||||
config["configurable"]["checkpoint_id"] = body.before
|
|
||||||
|
|
||||||
entries: list[HistoryEntry] = []
|
|
||||||
is_latest_checkpoint = True
|
|
||||||
try:
|
|
||||||
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
|
|
||||||
ckpt_config = getattr(checkpoint_tuple, "config", {})
|
|
||||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
|
||||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
|
|
||||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
|
|
||||||
parent_id = None
|
|
||||||
if parent_config:
|
|
||||||
parent_id = parent_config.get("configurable", {}).get("checkpoint_id")
|
|
||||||
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
|
|
||||||
# Build values from checkpoint channel_values
|
|
||||||
values: dict[str, Any] = {}
|
|
||||||
if title := channel_values.get("title"):
|
|
||||||
values["title"] = title
|
|
||||||
if thread_data := channel_values.get("thread_data"):
|
|
||||||
values["thread_data"] = thread_data
|
|
||||||
|
|
||||||
# Attach messages only to the latest checkpoint entry.
|
|
||||||
if is_latest_checkpoint:
|
|
||||||
messages = channel_values.get("messages")
|
|
||||||
if messages:
|
|
||||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
|
||||||
is_latest_checkpoint = False
|
|
||||||
|
|
||||||
# Derive next tasks
|
|
||||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
|
||||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
|
||||||
|
|
||||||
# Strip LangGraph internal keys from metadata
|
|
||||||
user_meta = {k: v for k, v in metadata.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
|
|
||||||
# Keep step for ordering context
|
|
||||||
if "step" in metadata:
|
|
||||||
user_meta["step"] = metadata["step"]
|
|
||||||
|
|
||||||
entries.append(
|
|
||||||
HistoryEntry(
|
|
||||||
checkpoint_id=checkpoint_id,
|
|
||||||
parent_checkpoint_id=parent_id,
|
|
||||||
metadata=user_meta,
|
|
||||||
values=values,
|
|
||||||
created_at=str(metadata.get("created_at", "")),
|
|
||||||
next=next_tasks,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to get thread history")
|
|
||||||
|
|
||||||
return entries
|
|
||||||
@@ -7,10 +7,10 @@ import stat
|
|||||||
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
|
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||||
from deerflow.config.paths import get_paths
|
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
|
from deerflow.config.paths import get_paths
|
||||||
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
PathTraversalError,
|
PathTraversalError,
|
||||||
delete_file_safe,
|
delete_file_safe,
|
||||||
@@ -56,7 +56,6 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=UploadResponse)
|
@router.post("", response_model=UploadResponse)
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=False)
|
|
||||||
async def upload_files(
|
async def upload_files(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -66,68 +65,69 @@ async def upload_files(
|
|||||||
if not files:
|
if not files:
|
||||||
raise HTTPException(status_code=400, detail="No files provided")
|
raise HTTPException(status_code=400, detail="No files provided")
|
||||||
|
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
uploads_dir = ensure_uploads_dir(thread_id)
|
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
|
||||||
uploaded_files = []
|
|
||||||
|
|
||||||
sandbox_provider = get_sandbox_provider()
|
|
||||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
|
||||||
sandbox = sandbox_provider.get(sandbox_id)
|
|
||||||
|
|
||||||
for file in files:
|
|
||||||
if not file.filename:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
safe_filename = normalize_filename(file.filename)
|
uploads_dir = ensure_uploads_dir(thread_id)
|
||||||
except ValueError:
|
except ValueError as e:
|
||||||
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
continue
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||||
|
uploaded_files = []
|
||||||
|
|
||||||
try:
|
sandbox_provider = get_sandbox_provider()
|
||||||
content = await file.read()
|
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||||
file_path = uploads_dir / safe_filename
|
sandbox = sandbox_provider.get(sandbox_id)
|
||||||
file_path.write_bytes(content)
|
|
||||||
|
|
||||||
virtual_path = upload_virtual_path(safe_filename)
|
for file in files:
|
||||||
|
if not file.filename:
|
||||||
|
continue
|
||||||
|
|
||||||
if sandbox_id != "local":
|
try:
|
||||||
_make_file_sandbox_writable(file_path)
|
safe_filename = normalize_filename(file.filename)
|
||||||
sandbox.update_file(virtual_path, content)
|
except ValueError:
|
||||||
|
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
||||||
|
continue
|
||||||
|
|
||||||
file_info = {
|
try:
|
||||||
"filename": safe_filename,
|
content = await file.read()
|
||||||
"size": str(len(content)),
|
file_path = uploads_dir / safe_filename
|
||||||
"path": str(sandbox_uploads / safe_filename),
|
file_path.write_bytes(content)
|
||||||
"virtual_path": virtual_path,
|
|
||||||
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
|
virtual_path = upload_virtual_path(safe_filename)
|
||||||
|
|
||||||
file_ext = file_path.suffix.lower()
|
if sandbox_id != "local":
|
||||||
if file_ext in CONVERTIBLE_EXTENSIONS:
|
_make_file_sandbox_writable(file_path)
|
||||||
md_path = await convert_file_to_markdown(file_path)
|
sandbox.update_file(virtual_path, content)
|
||||||
if md_path:
|
|
||||||
md_virtual_path = upload_virtual_path(md_path.name)
|
|
||||||
|
|
||||||
if sandbox_id != "local":
|
file_info = {
|
||||||
_make_file_sandbox_writable(md_path)
|
"filename": safe_filename,
|
||||||
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
"size": str(len(content)),
|
||||||
|
"path": str(sandbox_uploads / safe_filename),
|
||||||
|
"virtual_path": virtual_path,
|
||||||
|
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
||||||
|
}
|
||||||
|
|
||||||
file_info["markdown_file"] = md_path.name
|
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
|
||||||
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
|
||||||
file_info["markdown_virtual_path"] = md_virtual_path
|
|
||||||
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
|
|
||||||
|
|
||||||
uploaded_files.append(file_info)
|
file_ext = file_path.suffix.lower()
|
||||||
|
if file_ext in CONVERTIBLE_EXTENSIONS:
|
||||||
|
md_path = await convert_file_to_markdown(file_path)
|
||||||
|
if md_path:
|
||||||
|
md_virtual_path = upload_virtual_path(md_path.name)
|
||||||
|
|
||||||
except Exception as e:
|
if sandbox_id != "local":
|
||||||
logger.error(f"Failed to upload {file.filename}: {e}")
|
_make_file_sandbox_writable(md_path)
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
||||||
|
|
||||||
|
file_info["markdown_file"] = md_path.name
|
||||||
|
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
||||||
|
file_info["markdown_virtual_path"] = md_virtual_path
|
||||||
|
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
|
||||||
|
|
||||||
|
uploaded_files.append(file_info)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to upload {file.filename}: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
||||||
|
|
||||||
return UploadResponse(
|
return UploadResponse(
|
||||||
success=True,
|
success=True,
|
||||||
@@ -137,26 +137,25 @@ async def upload_files(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/list", response_model=dict)
|
@router.get("/list", response_model=dict)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
||||||
"""List all files in a thread's uploads directory."""
|
"""List all files in a thread's uploads directory."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
uploads_dir = get_uploads_dir(thread_id)
|
try:
|
||||||
except ValueError as e:
|
uploads_dir = get_uploads_dir(thread_id)
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
except ValueError as e:
|
||||||
result = list_files_in_dir(uploads_dir)
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
enrich_file_listing(result, thread_id)
|
result = list_files_in_dir(uploads_dir)
|
||||||
|
enrich_file_listing(result, thread_id)
|
||||||
|
|
||||||
# Gateway additionally includes the sandbox-relative path.
|
# Gateway additionally includes the sandbox-relative path.
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||||
for f in result["files"]:
|
for f in result["files"]:
|
||||||
f["path"] = str(sandbox_uploads / f["filename"])
|
f["path"] = str(sandbox_uploads / f["filename"])
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{filename}")
|
@router.delete("/{filename}")
|
||||||
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
|
||||||
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
|
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
|
||||||
"""Delete a file from a thread's uploads directory."""
|
"""Delete a file from a thread's uploads directory."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user