mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
refactor(routers): reorganize routers with new langgraph/ subdirectory
Restructure app/gateway/routers/: - Add langgraph/ subdirectory for LangGraph-related endpoints: - threads.py - thread management - runs.py - run execution and streaming - feedback.py - feedback endpoints - suggestions.py - follow-up suggestions Remove old standalone routers: - threads.py → langgraph/threads.py - thread_runs.py → langgraph/runs.py - runs.py (stateless) → langgraph/runs.py - feedback.py → langgraph/feedback.py Update existing routers: - memory.py, uploads.py, artifacts.py, suggestions.py Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
from .feedback import router as feedback_router
|
||||
from .runs import router as runs_router
|
||||
from .suggestions import router as suggestion_router
|
||||
from .threads import router as threads_router
|
||||
|
||||
__all__ = ["feedback_router", "runs_router", "threads_router", "suggestion_router"]
|
||||
@@ -0,0 +1,179 @@
|
||||
"""LangGraph-compatible run feedback endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.dependencies import get_feedback_repository, get_run_repository
|
||||
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
|
||||
from app.plugins.auth.security.dependencies import get_current_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["feedback"])
|
||||
|
||||
|
||||
class FeedbackCreateRequest(BaseModel):
|
||||
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
|
||||
comment: str | None = Field(default=None, description="Optional text feedback")
|
||||
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
||||
|
||||
|
||||
class FeedbackResponse(BaseModel):
|
||||
feedback_id: str
|
||||
run_id: str
|
||||
thread_id: str
|
||||
owner_id: str | None = None
|
||||
message_id: str | None = None
|
||||
rating: int
|
||||
comment: str | None = None
|
||||
created_at: str = ""
|
||||
|
||||
|
||||
class FeedbackStatsResponse(BaseModel):
|
||||
run_id: str
|
||||
total: int = 0
|
||||
positive: int = 0
|
||||
negative: int = 0
|
||||
|
||||
|
||||
async def _validate_run_scope(thread_id: str, run_id: str, request: Request) -> None:
|
||||
run_store = get_run_repository(request)
|
||||
if resolve_request_user_id(request) is None:
|
||||
run = await run_store.get(run_id, user_id=None)
|
||||
else:
|
||||
with bind_request_actor_context(request):
|
||||
run = await run_store.get(run_id)
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
if run.get("thread_id") != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
||||
|
||||
|
||||
async def _get_current_user(request: Request) -> str | None:
|
||||
"""Extract current user id from auth dependencies when available."""
|
||||
return await get_current_user_id(request)
|
||||
|
||||
|
||||
async def _create_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
body: FeedbackCreateRequest,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
if body.rating not in (1, -1):
|
||||
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
||||
|
||||
await _validate_run_scope(thread_id, run_id, request)
|
||||
user_id = await _get_current_user(request)
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
return await feedback_repo.create(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=body.rating,
|
||||
user_id=user_id,
|
||||
message_id=body.message_id,
|
||||
comment=body.comment,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||
async def upsert_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
body: FeedbackCreateRequest,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
"""Create or replace the run-level feedback record."""
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
user_id = await _get_current_user(request)
|
||||
if user_id is not None:
|
||||
return await feedback_repo.upsert(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=body.rating,
|
||||
user_id=user_id,
|
||||
comment=body.comment,
|
||||
)
|
||||
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
|
||||
for item in existing:
|
||||
feedback_id = item.get("feedback_id")
|
||||
if isinstance(feedback_id, str):
|
||||
await feedback_repo.delete(feedback_id)
|
||||
return await _create_feedback(thread_id, run_id, body, request)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||
async def create_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
body: FeedbackCreateRequest,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
"""Submit feedback for a run."""
|
||||
return await _create_feedback(thread_id, run_id, body, request)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
|
||||
async def list_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List all feedback for a run."""
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
user_id = await _get_current_user(request)
|
||||
return await feedback_repo.list_by_run(thread_id, run_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
|
||||
async def feedback_stats(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
"""Get aggregated feedback stats for a run."""
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
return await feedback_repo.aggregate_by_run(thread_id, run_id)
|
||||
|
||||
|
||||
@router.delete("/{thread_id}/runs/{run_id}/feedback")
|
||||
async def delete_run_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> dict[str, bool]:
|
||||
"""Delete all feedback records for a run."""
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
user_id = await _get_current_user(request)
|
||||
if user_id is not None:
|
||||
return {"success": await feedback_repo.delete_by_run(thread_id=thread_id, run_id=run_id, user_id=user_id)}
|
||||
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
|
||||
for item in existing:
|
||||
feedback_id = item.get("feedback_id")
|
||||
if isinstance(feedback_id, str):
|
||||
await feedback_repo.delete(feedback_id)
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
|
||||
async def delete_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
feedback_id: str,
|
||||
request: Request,
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a single feedback record."""
|
||||
feedback_repo = get_feedback_repository(request)
|
||||
existing = await feedback_repo.get(feedback_id)
|
||||
if existing is None:
|
||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||
if existing.get("thread_id") != thread_id or existing.get("run_id") != run_id:
|
||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found in run {run_id}")
|
||||
deleted = await feedback_repo.delete(feedback_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||
return {"success": True}
|
||||
@@ -0,0 +1,501 @@
|
||||
"""LangGraph-compatible runs endpoints backed by RunsFacade."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||
from app.gateway.services.runs.facade_factory import build_runs_facade_from_request
|
||||
from app.gateway.services.runs.input import (
|
||||
AdaptedRunRequest,
|
||||
RunSpecBuilder,
|
||||
UnsupportedRunFeatureError,
|
||||
adapt_create_run_request,
|
||||
adapt_create_stream_request,
|
||||
adapt_create_wait_request,
|
||||
adapt_join_stream_request,
|
||||
adapt_join_wait_request,
|
||||
)
|
||||
from deerflow.runtime.runs.types import RunRecord, RunSpec
|
||||
from deerflow.runtime.stream_bridge import JSONValue, StreamEvent
|
||||
|
||||
router = APIRouter(tags=["runs"])
|
||||
|
||||
|
||||
class RunCreateRequest(BaseModel):
|
||||
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
|
||||
follow_up_to_run_id: str | None = Field(default=None, description="Lineage link to the prior run")
|
||||
input: dict[str, JSONValue] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
|
||||
command: dict[str, JSONValue] | None = Field(default=None, description="LangGraph Command")
|
||||
metadata: dict[str, JSONValue] | None = Field(default=None, description="Run metadata")
|
||||
config: dict[str, JSONValue] | None = Field(default=None, description="RunnableConfig overrides")
|
||||
context: dict[str, JSONValue] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
||||
webhook: str | None = Field(default=None, description="Completion callback URL")
|
||||
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
||||
checkpoint: dict[str, JSONValue] | None = Field(default=None, description="Full checkpoint object")
|
||||
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
|
||||
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
|
||||
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
|
||||
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
|
||||
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
|
||||
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
|
||||
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
|
||||
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
|
||||
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
||||
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
||||
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
run_id: str
|
||||
thread_id: str
|
||||
assistant_id: str | None = None
|
||||
status: str
|
||||
metadata: dict[str, JSONValue] = Field(default_factory=dict)
|
||||
multitask_strategy: str = "reject"
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
|
||||
|
||||
class RunDeleteResponse(BaseModel):
|
||||
deleted: bool
|
||||
|
||||
|
||||
class RunMessageResponse(BaseModel):
|
||||
run_id: str
|
||||
content: JSONValue
|
||||
metadata: dict[str, JSONValue] = Field(default_factory=dict)
|
||||
created_at: str
|
||||
seq: int
|
||||
|
||||
|
||||
class RunMessagesResponse(BaseModel):
|
||||
data: list[RunMessageResponse]
|
||||
hasMore: bool = False
|
||||
|
||||
|
||||
def format_sse(event: str, data: JSONValue, *, event_id: str | None = None) -> str:
|
||||
"""Format a single SSE frame."""
|
||||
payload = json.dumps(data, default=str, ensure_ascii=False)
|
||||
parts = [f"event: {event}", f"data: {payload}"]
|
||||
if event_id:
|
||||
parts.append(f"id: {event_id}")
|
||||
parts.append("")
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _record_to_response(record: RunRecord) -> RunResponse:
|
||||
return RunResponse(
|
||||
run_id=record.run_id,
|
||||
thread_id=record.thread_id,
|
||||
assistant_id=record.assistant_id,
|
||||
status=record.status,
|
||||
metadata=record.metadata,
|
||||
multitask_strategy=record.multitask_strategy,
|
||||
created_at=record.created_at,
|
||||
updated_at=record.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _trim_paginated_rows(
|
||||
rows: list[dict],
|
||||
*,
|
||||
limit: int,
|
||||
after_seq: int | None,
|
||||
) -> tuple[list[dict], bool]:
|
||||
has_more = len(rows) > limit
|
||||
if not has_more:
|
||||
return rows, False
|
||||
if after_seq is not None:
|
||||
return rows[:limit], True
|
||||
return rows[-limit:], True
|
||||
|
||||
|
||||
def _event_to_run_message(event: dict) -> RunMessageResponse:
|
||||
return RunMessageResponse(
|
||||
run_id=str(event["run_id"]),
|
||||
content=event.get("content"),
|
||||
metadata=dict(event.get("metadata") or {}),
|
||||
created_at=str(event.get("created_at") or ""),
|
||||
seq=int(event["seq"]),
|
||||
)
|
||||
|
||||
|
||||
async def _sse_consumer(
|
||||
stream: AsyncIterator[StreamEvent],
|
||||
request: Request,
|
||||
*,
|
||||
cancel_on_disconnect: bool,
|
||||
cancel_run,
|
||||
run_id: str,
|
||||
) -> AsyncIterator[str]:
|
||||
try:
|
||||
async for event in stream:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
if event.event == "__heartbeat__":
|
||||
yield ": heartbeat\n\n"
|
||||
continue
|
||||
|
||||
if event.event == "__end__":
|
||||
yield format_sse("end", None, event_id=event.id or None)
|
||||
return
|
||||
|
||||
if event.event == "__cancelled__":
|
||||
yield format_sse("cancel", None, event_id=event.id or None)
|
||||
return
|
||||
|
||||
yield format_sse(event.event, event.data, event_id=event.id or None)
|
||||
finally:
|
||||
if cancel_on_disconnect:
|
||||
await cancel_run(run_id)
|
||||
|
||||
|
||||
def _get_run_event_store(request: Request):
|
||||
event_store = getattr(request.app.state, "run_event_store", None)
|
||||
if event_store is None:
|
||||
raise HTTPException(status_code=503, detail="Run event store not available")
|
||||
return event_store
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
|
||||
async def list_runs(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
status: str | None = None,
|
||||
) -> list[RunResponse]:
|
||||
# Accepted for API compatibility; field projection is not implemented yet.
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
records = await facade.list_runs(thread_id)
|
||||
if status is not None:
|
||||
records = [record for record in records if record.status == status]
|
||||
records = records[offset : offset + limit]
|
||||
return [_record_to_response(record) for record in records]
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
|
||||
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/messages", response_model=RunMessagesResponse)
|
||||
async def run_messages(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> RunMessagesResponse:
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
event_store = _get_run_event_store(request)
|
||||
with bind_request_actor_context(request):
|
||||
rows = await event_store.list_messages_by_run(
|
||||
thread_id,
|
||||
run_id,
|
||||
limit=limit + 1,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
)
|
||||
page, has_more = _trim_paginated_rows(rows, limit=limit, after_seq=after_seq)
|
||||
return RunMessagesResponse(data=[_event_to_run_message(row) for row in page], hasMore=has_more)
|
||||
|
||||
|
||||
def _build_spec(
|
||||
*,
|
||||
adapted: AdaptedRunRequest,
|
||||
) -> RunSpec:
|
||||
try:
|
||||
return RunSpecBuilder().build(adapted)
|
||||
except UnsupportedRunFeatureError as exc:
|
||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs", response_model=RunResponse)
|
||||
async def create_run(
|
||||
thread_id: str,
|
||||
body: RunCreateRequest,
|
||||
request: Request,
|
||||
) -> Response:
|
||||
adapted = adapt_create_run_request(
|
||||
thread_id=thread_id,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
spec = _build_spec(adapted=adapted)
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.create_background(spec)
|
||||
return Response(
|
||||
content=_record_to_response(record).model_dump_json(),
|
||||
media_type="application/json",
|
||||
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/stream")
|
||||
async def stream_run(
|
||||
thread_id: str,
|
||||
body: RunCreateRequest,
|
||||
request: Request,
|
||||
) -> StreamingResponse:
|
||||
adapted = adapt_create_stream_request(
|
||||
thread_id=thread_id,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
|
||||
spec = _build_spec(adapted=adapted)
|
||||
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record, stream = await facade.create_and_stream(spec)
|
||||
|
||||
return StreamingResponse(
|
||||
_sse_consumer(
|
||||
stream,
|
||||
request,
|
||||
cancel_on_disconnect=spec.on_disconnect == "cancel",
|
||||
cancel_run=facade.cancel,
|
||||
run_id=record.run_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/wait")
|
||||
async def wait_run(
|
||||
thread_id: str,
|
||||
body: RunCreateRequest,
|
||||
request: Request,
|
||||
) -> Response:
|
||||
adapted = adapt_create_wait_request(
|
||||
thread_id=thread_id,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
spec = _build_spec(adapted=adapted)
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record, result = await facade.create_and_wait(spec)
|
||||
return Response(
|
||||
content=json.dumps(result, default=str, ensure_ascii=False),
|
||||
media_type="application/json",
|
||||
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/runs", response_model=RunResponse)
|
||||
async def create_stateless_run(body: RunCreateRequest, request: Request) -> Response:
|
||||
adapted = adapt_create_run_request(
|
||||
thread_id=None,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
spec = _build_spec(adapted=adapted)
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.create_background(spec)
|
||||
return Response(
|
||||
content=_record_to_response(record).model_dump_json(),
|
||||
media_type="application/json",
|
||||
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/runs/stream")
|
||||
async def create_stateless_stream_run(body: RunCreateRequest, request: Request) -> StreamingResponse:
|
||||
adapted = adapt_create_stream_request(
|
||||
thread_id=None,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
spec = _build_spec(adapted=adapted)
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record, stream = await facade.create_and_stream(spec)
|
||||
|
||||
return StreamingResponse(
|
||||
_sse_consumer(
|
||||
stream,
|
||||
request,
|
||||
cancel_on_disconnect=spec.on_disconnect == "cancel",
|
||||
cancel_run=facade.cancel,
|
||||
run_id=record.run_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/runs/wait")
|
||||
async def wait_stateless_run(body: RunCreateRequest, request: Request) -> Response:
|
||||
adapted = adapt_create_wait_request(
|
||||
thread_id=None,
|
||||
body=body.model_dump(),
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
spec = _build_spec(adapted=adapted)
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record, result = await facade.create_and_wait(spec)
|
||||
return Response(
|
||||
content=json.dumps(result, default=str, ensure_ascii=False),
|
||||
media_type="application/json",
|
||||
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
|
||||
)
|
||||
|
||||
|
||||
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
|
||||
async def stream_existing_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
action: Literal["interrupt", "rollback"] | None = None,
|
||||
wait: bool = False,
|
||||
cancel_on_disconnect: bool = False,
|
||||
stream_mode: str | None = None,
|
||||
) -> StreamingResponse | Response:
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
if action is not None:
|
||||
with bind_request_actor_context(request):
|
||||
cancelled = await facade.cancel(run_id, action=action)
|
||||
if not cancelled:
|
||||
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
|
||||
if wait:
|
||||
with bind_request_actor_context(request):
|
||||
await facade.join_wait(run_id)
|
||||
return Response(status_code=204)
|
||||
|
||||
adapted = adapt_join_stream_request(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
with bind_request_actor_context(request):
|
||||
stream = await facade.join_stream(run_id, last_event_id=adapted.last_event_id)
|
||||
|
||||
return StreamingResponse(
|
||||
_sse_consumer(
|
||||
stream,
|
||||
request,
|
||||
cancel_on_disconnect=cancel_on_disconnect,
|
||||
cancel_run=facade.cancel,
|
||||
run_id=run_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/join")
|
||||
async def join_existing_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
cancel_on_disconnect: bool = False,
|
||||
) -> JSONValue:
|
||||
# Accepted for API compatibility; current join_wait path does not change
|
||||
# behavior based on client disconnect.
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
adapted = adapt_join_wait_request(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
headers=dict(request.headers),
|
||||
query=dict(request.query_params),
|
||||
)
|
||||
with bind_request_actor_context(request):
|
||||
return await facade.join_wait(run_id, last_event_id=adapted.last_event_id)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/cancel")
|
||||
async def cancel_existing_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
wait: bool = False,
|
||||
action: Literal["interrupt", "rollback"] = "interrupt",
|
||||
) -> JSONValue:
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
with bind_request_actor_context(request):
|
||||
cancelled = await facade.cancel(run_id, action=action)
|
||||
if not cancelled:
|
||||
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
|
||||
if wait:
|
||||
with bind_request_actor_context(request):
|
||||
return await facade.join_wait(run_id)
|
||||
return {}
|
||||
|
||||
|
||||
@router.delete("/{thread_id}/runs/{run_id}", response_model=RunDeleteResponse)
|
||||
async def delete_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> RunDeleteResponse:
|
||||
facade = build_runs_facade_from_request(request)
|
||||
with bind_request_actor_context(request):
|
||||
record = await facade.get_run(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
with bind_request_actor_context(request):
|
||||
deleted = await facade.delete_run(run_id)
|
||||
return RunDeleteResponse(deleted=deleted)
|
||||
@@ -0,0 +1,132 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["suggestions"])
|
||||
|
||||
|
||||
class SuggestionMessage(BaseModel):
|
||||
role: str = Field(..., description="Message role: user|assistant")
|
||||
content: str = Field(..., description="Message content as plain text")
|
||||
|
||||
|
||||
class SuggestionsRequest(BaseModel):
|
||||
messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
|
||||
n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
|
||||
model_name: str | None = Field(default=None, description="Optional model override")
|
||||
|
||||
|
||||
class SuggestionsResponse(BaseModel):
|
||||
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
|
||||
|
||||
|
||||
def _strip_markdown_code_fence(text: str) -> str:
|
||||
stripped = text.strip()
|
||||
if not stripped.startswith("```"):
|
||||
return stripped
|
||||
lines = stripped.splitlines()
|
||||
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
|
||||
return "\n".join(lines[1:-1]).strip()
|
||||
return stripped
|
||||
|
||||
|
||||
def _parse_json_string_list(text: str) -> list[str] | None:
|
||||
candidate = _strip_markdown_code_fence(text)
|
||||
start = candidate.find("[")
|
||||
end = candidate.rfind("]")
|
||||
if start == -1 or end == -1 or end <= start:
|
||||
return None
|
||||
candidate = candidate[start : end + 1]
|
||||
try:
|
||||
data = json.loads(candidate)
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(data, list):
|
||||
return None
|
||||
out: list[str] = []
|
||||
for item in data:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
s = item.strip()
|
||||
if not s:
|
||||
continue
|
||||
out.append(s)
|
||||
return out
|
||||
|
||||
|
||||
def _extract_response_text(content: object) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append(block)
|
||||
elif isinstance(block, dict) and block.get("type") in {"text", "output_text"}:
|
||||
text = block.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "\n".join(parts) if parts else ""
|
||||
if content is None:
|
||||
return ""
|
||||
return str(content)
|
||||
|
||||
|
||||
def _format_conversation(messages: list[SuggestionMessage]) -> str:
|
||||
parts: list[str] = []
|
||||
for m in messages:
|
||||
role = m.role.strip().lower()
|
||||
if role in ("user", "human"):
|
||||
parts.append(f"User: {m.content.strip()}")
|
||||
elif role in ("assistant", "ai"):
|
||||
parts.append(f"Assistant: {m.content.strip()}")
|
||||
else:
|
||||
parts.append(f"{m.role}: {m.content.strip()}")
|
||||
return "\n".join(parts).strip()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/threads/{thread_id}/suggestions",
|
||||
response_model=SuggestionsResponse,
|
||||
summary="Generate Follow-up Questions",
|
||||
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
||||
)
|
||||
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
|
||||
if not request.messages:
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
|
||||
n = request.n
|
||||
conversation = _format_conversation(request.messages)
|
||||
if not conversation:
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
|
||||
system_instruction = (
|
||||
"You are generating follow-up questions to help the user continue the conversation.\n"
|
||||
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
|
||||
"Requirements:\n"
|
||||
"- Questions must be relevant to the preceding conversation.\n"
|
||||
"- Questions must be written in the same language as the user.\n"
|
||||
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
|
||||
"- Do NOT include numbering, markdown, or any extra text.\n"
|
||||
"- Output MUST be a JSON array of strings only.\n"
|
||||
)
|
||||
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
||||
|
||||
try:
|
||||
model = create_chat_model(name=request.model_name, thinking_enabled=False)
|
||||
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
|
||||
raw = _extract_response_text(response.content)
|
||||
suggestions = _parse_json_string_list(raw) or []
|
||||
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
|
||||
cleaned = cleaned[:n]
|
||||
return SuggestionsResponse(suggestions=cleaned)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
@@ -0,0 +1,455 @@
|
||||
"""Thread management endpoints.
|
||||
|
||||
Provides CRUD operations for threads and checkpoint state management.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.dependencies import CurrentCheckpointer, CurrentRunRepository, CurrentThreadMetaStorage
|
||||
from app.infra.storage import ThreadMetaStorage
|
||||
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["threads"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / Response Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ThreadCreateRequest(BaseModel):
|
||||
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
||||
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
||||
|
||||
|
||||
class ThreadSearchRequest(BaseModel):
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
|
||||
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
|
||||
offset: int = Field(default=0, ge=0, description="Pagination offset")
|
||||
status: str | None = Field(default=None, description="Filter by thread status")
|
||||
user_id: str | None = Field(default=None, description="Filter by user ID")
|
||||
assistant_id: str | None = Field(default=None, description="Filter by assistant ID")
|
||||
|
||||
|
||||
class ThreadResponse(BaseModel):
|
||||
thread_id: str = Field(description="Unique thread identifier")
|
||||
status: str = Field(default="idle", description="Thread status")
|
||||
created_at: str = Field(default="", description="ISO timestamp")
|
||||
updated_at: str = Field(default="", description="ISO timestamp")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
|
||||
values: dict[str, Any] = Field(default_factory=dict, description="Current state values")
|
||||
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
|
||||
|
||||
|
||||
class ThreadDeleteResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class ThreadStateUpdateRequest(BaseModel):
|
||||
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
|
||||
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
|
||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
||||
as_node: str | None = Field(default=None, description="Node identity for the update")
|
||||
|
||||
|
||||
class ThreadStateResponse(BaseModel):
|
||||
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
|
||||
next: list[str] = Field(default_factory=list, description="Next nodes to execute")
|
||||
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
|
||||
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
|
||||
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
|
||||
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
|
||||
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
|
||||
|
||||
|
||||
class ThreadHistoryRequest(BaseModel):
|
||||
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
|
||||
before: str | None = Field(default=None, description="Cursor for pagination (checkpoint_id)")
|
||||
|
||||
|
||||
class HistoryEntry(BaseModel):
|
||||
checkpoint_id: str
|
||||
parent_checkpoint_id: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
values: dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: str | None = None
|
||||
next: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def sanitize_log_param(value: str) -> str:
|
||||
"""Strip control characters to prevent log injection."""
|
||||
|
||||
return value.replace("\n", "").replace("\r", "").replace("\x00", "")
|
||||
|
||||
|
||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
|
||||
"""Delete local filesystem data for a thread."""
|
||||
path_manager = paths or get_paths()
|
||||
try:
|
||||
path_manager.delete_thread_dir(thread_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
except FileNotFoundError:
|
||||
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
|
||||
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
|
||||
|
||||
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
|
||||
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
|
||||
|
||||
|
||||
async def _thread_or_run_exists(
|
||||
*,
|
||||
request: Request,
|
||||
thread_id: str,
|
||||
thread_meta_storage: ThreadMetaStorage,
|
||||
run_repo,
|
||||
) -> bool:
|
||||
request_user_id = resolve_request_user_id(request)
|
||||
|
||||
if request_user_id is None:
|
||||
thread = await thread_meta_storage.get_thread(thread_id, user_id=None)
|
||||
if thread is not None:
|
||||
return True
|
||||
runs = await run_repo.list_by_thread(thread_id, limit=1, user_id=None)
|
||||
return bool(runs)
|
||||
|
||||
with bind_request_actor_context(request):
|
||||
thread = await thread_meta_storage.get_thread(thread_id)
|
||||
if thread is not None:
|
||||
return True
|
||||
runs = await run_repo.list_by_thread(thread_id, limit=1)
|
||||
return bool(runs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("", response_model=ThreadResponse)
|
||||
async def create_thread(
|
||||
body: ThreadCreateRequest,
|
||||
request: Request,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
) -> ThreadResponse:
|
||||
"""Create a new thread."""
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
|
||||
request_user_id = resolve_request_user_id(request)
|
||||
if request_user_id is None:
|
||||
existing = await thread_meta_storage.get_thread(thread_id, user_id=None)
|
||||
else:
|
||||
with bind_request_actor_context(request):
|
||||
existing = await thread_meta_storage.get_thread(thread_id)
|
||||
if existing is not None:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=existing.status,
|
||||
created_at=existing.created_time.isoformat() if existing.created_time else "",
|
||||
updated_at=existing.updated_time.isoformat() if existing.updated_time else "",
|
||||
metadata=existing.metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
if request_user_id is None:
|
||||
created = await thread_meta_storage.ensure_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
user_id=None,
|
||||
)
|
||||
else:
|
||||
with bind_request_actor_context(request):
|
||||
created = await thread_meta_storage.ensure_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to create thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
||||
|
||||
logger.info("Thread created: %s", sanitize_log_param(thread_id))
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=created.status,
|
||||
created_at=created.created_time.isoformat() if created.created_time else "",
|
||||
updated_at=created.updated_time.isoformat() if created.updated_time else "",
|
||||
metadata=created.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/search", response_model=list[ThreadResponse])
|
||||
async def search_threads(
|
||||
body: ThreadSearchRequest,
|
||||
request: Request,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
) -> list[ThreadResponse]:
|
||||
"""Search threads with filters."""
|
||||
try:
|
||||
request_user_id = resolve_request_user_id(request)
|
||||
if request_user_id is None:
|
||||
threads = await thread_meta_storage.search_threads(
|
||||
metadata=body.metadata or None,
|
||||
status=body.status,
|
||||
user_id=body.user_id,
|
||||
assistant_id=body.assistant_id,
|
||||
limit=body.limit,
|
||||
offset=body.offset,
|
||||
)
|
||||
else:
|
||||
with bind_request_actor_context(request):
|
||||
threads = await thread_meta_storage.search_threads(
|
||||
metadata=body.metadata or None,
|
||||
status=body.status,
|
||||
assistant_id=body.assistant_id,
|
||||
limit=body.limit,
|
||||
offset=body.offset,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to search threads")
|
||||
raise HTTPException(status_code=500, detail="Failed to search threads")
|
||||
|
||||
return [
|
||||
ThreadResponse(
|
||||
thread_id=t.thread_id,
|
||||
status=t.status,
|
||||
created_at=t.created_time.isoformat() if t.created_time else "",
|
||||
updated_at=t.updated_time.isoformat() if t.updated_time else "",
|
||||
metadata=t.metadata,
|
||||
values={"title": t.display_name} if t.display_name else {},
|
||||
interrupts={},
|
||||
)
|
||||
for t in threads
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
|
||||
async def delete_thread(
|
||||
thread_id: str,
|
||||
checkpointer: CurrentCheckpointer,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
) -> ThreadDeleteResponse:
|
||||
"""Delete a thread and all associated data."""
|
||||
response = _delete_thread_data(thread_id)
|
||||
|
||||
# Remove checkpoints (best-effort)
|
||||
try:
|
||||
if hasattr(checkpointer, "adelete_thread"):
|
||||
await checkpointer.adelete_thread(thread_id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete checkpoints for thread %s", sanitize_log_param(thread_id))
|
||||
|
||||
# Remove thread_meta (best-effort)
|
||||
try:
|
||||
await thread_meta_storage.delete_thread(thread_id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete thread_meta for %s", sanitize_log_param(thread_id))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
async def get_thread_state(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
checkpointer: CurrentCheckpointer,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
run_repo: CurrentRunRepository,
|
||||
) -> ThreadStateResponse:
|
||||
"""Get the latest state snapshot for a thread."""
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
except Exception:
|
||||
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||
|
||||
if checkpoint_tuple is None:
|
||||
if await _thread_or_run_exists(
|
||||
request=request,
|
||||
thread_id=thread_id,
|
||||
thread_meta_storage=thread_meta_storage,
|
||||
run_repo=run_repo,
|
||||
):
|
||||
return ThreadStateResponse()
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
|
||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
|
||||
|
||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=serialize_channel_values(channel_values),
|
||||
next=next_nodes,
|
||||
tasks=tasks,
|
||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||
checkpoint_id=checkpoint_id,
|
||||
parent_checkpoint_id=parent_checkpoint_id,
|
||||
metadata=metadata,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
async def update_thread_state(
|
||||
thread_id: str,
|
||||
body: ThreadStateUpdateRequest,
|
||||
checkpointer: CurrentCheckpointer,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
) -> ThreadStateResponse:
|
||||
"""Update thread state (human-in-the-loop or title rename)."""
|
||||
read_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
if body.checkpoint_id:
|
||||
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
|
||||
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||
|
||||
if checkpoint_tuple is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
|
||||
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
|
||||
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
|
||||
|
||||
if body.values:
|
||||
channel_values.update(body.values)
|
||||
|
||||
checkpoint["channel_values"] = channel_values
|
||||
metadata["updated_at"] = time.time()
|
||||
|
||||
if body.as_node:
|
||||
metadata["source"] = "update"
|
||||
metadata["step"] = metadata.get("step", 0) + 1
|
||||
metadata["writes"] = {body.as_node: body.values}
|
||||
|
||||
write_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
try:
|
||||
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
|
||||
except Exception:
|
||||
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to update thread state")
|
||||
|
||||
new_checkpoint_id: str | None = None
|
||||
if isinstance(new_config, dict):
|
||||
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
# Sync title to thread_meta
|
||||
if body.values and "title" in body.values:
|
||||
new_title = body.values["title"]
|
||||
if new_title:
|
||||
try:
|
||||
await thread_meta_storage.sync_thread_title(
|
||||
thread_id=thread_id,
|
||||
title=new_title,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title for %s", sanitize_log_param(thread_id))
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=serialize_channel_values(channel_values),
|
||||
next=[],
|
||||
metadata=metadata,
|
||||
checkpoint_id=new_checkpoint_id,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
|
||||
async def get_thread_history(
|
||||
thread_id: str,
|
||||
body: ThreadHistoryRequest,
|
||||
request: Request,
|
||||
checkpointer: CurrentCheckpointer,
|
||||
thread_meta_storage: CurrentThreadMetaStorage,
|
||||
run_repo: CurrentRunRepository,
|
||||
) -> list[HistoryEntry]:
|
||||
"""Get checkpoint history for a thread."""
|
||||
config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
if body.before:
|
||||
config["configurable"]["checkpoint_id"] = body.before
|
||||
|
||||
entries: list[HistoryEntry] = []
|
||||
is_first = True
|
||||
|
||||
try:
|
||||
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
|
||||
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
|
||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
|
||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
|
||||
parent_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
values: dict[str, Any] = {}
|
||||
if title := channel_values.get("title"):
|
||||
values["title"] = title
|
||||
if is_first and (messages := channel_values.get("messages")):
|
||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||
is_first = False
|
||||
|
||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||
|
||||
entries.append(
|
||||
HistoryEntry(
|
||||
checkpoint_id=checkpoint_id,
|
||||
parent_checkpoint_id=parent_id,
|
||||
metadata=metadata,
|
||||
values=values,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
next=next_nodes,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread history")
|
||||
|
||||
if not entries and await _thread_or_run_exists(
|
||||
request=request,
|
||||
thread_id=thread_id,
|
||||
thread_meta_storage=thread_meta_storage,
|
||||
run_repo=run_repo,
|
||||
):
|
||||
return []
|
||||
|
||||
return entries
|
||||
Reference in New Issue
Block a user