mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
feat(threads): switch search endpoint to threads_meta table and sync title
- POST /api/threads/search now queries the threads_meta table directly, removing the two-phase Store + Checkpointer scan approach
- Add ThreadMetaRepository.search() with metadata/status filters
- Add ThreadMetaRepository.update_display_name() for title sync
- Worker syncs checkpoint title to threads_meta.display_name on run completion
- Map display_name to values.title in search response for API compatibility

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -317,107 +317,31 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
|
|
||||||
@router.post("/search", response_model=list[ThreadResponse])
|
@router.post("/search", response_model=list[ThreadResponse])
|
||||||
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
|
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
|
||||||
"""Search and list threads.
|
"""Search and list threads from the threads_meta table."""
|
||||||
|
from app.gateway.deps import get_thread_meta_repo
|
||||||
|
|
||||||
Two-phase approach:
|
repo = get_thread_meta_repo(request)
|
||||||
|
if repo is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Thread metadata store not available")
|
||||||
|
|
||||||
**Phase 1 — Store (fast path, O(threads))**: returns threads that were
|
rows = await repo.search(
|
||||||
created or run through this Gateway. Store records are tiny metadata
|
metadata=body.metadata or None,
|
||||||
dicts so fetching all of them at once is cheap.
|
status=body.status,
|
||||||
|
limit=body.limit,
|
||||||
**Phase 2 — Checkpointer supplement (lazy migration)**: threads that
|
offset=body.offset,
|
||||||
were created directly by LangGraph Server (and therefore absent from the
|
)
|
||||||
Store) are discovered here by iterating the shared checkpointer. Any
|
return [
|
||||||
newly found thread is immediately written to the Store so that the next
|
ThreadResponse(
|
||||||
search skips Phase 2 for that thread — the Store converges to a full
|
thread_id=r["thread_id"],
|
||||||
index over time without a one-shot migration job.
|
status=r.get("status", "idle"),
|
||||||
"""
|
created_at=r.get("created_at", ""),
|
||||||
store = get_store(request)
|
updated_at=r.get("updated_at", ""),
|
||||||
checkpointer = get_checkpointer(request)
|
metadata=r.get("metadata", {}),
|
||||||
|
values={"title": r["display_name"]} if r.get("display_name") else {},
|
||||||
# -----------------------------------------------------------------------
|
interrupts={},
|
||||||
# Phase 1: Store
|
)
|
||||||
# -----------------------------------------------------------------------
|
for r in rows
|
||||||
merged: dict[str, ThreadResponse] = {}
|
]
|
||||||
|
|
||||||
if store is not None:
|
|
||||||
try:
|
|
||||||
items = await store.asearch(THREADS_NS, limit=10_000)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Store search failed — falling back to checkpointer only", exc_info=True)
|
|
||||||
items = []
|
|
||||||
|
|
||||||
for item in items:
|
|
||||||
val = item.value
|
|
||||||
merged[val["thread_id"]] = ThreadResponse(
|
|
||||||
thread_id=val["thread_id"],
|
|
||||||
status=val.get("status", "idle"),
|
|
||||||
created_at=str(val.get("created_at", "")),
|
|
||||||
updated_at=str(val.get("updated_at", "")),
|
|
||||||
metadata=val.get("metadata", {}),
|
|
||||||
values=val.get("values", {}),
|
|
||||||
)
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
# Phase 2: Checkpointer supplement
|
|
||||||
# Discovers threads not yet in the Store (e.g. created by LangGraph
|
|
||||||
# Server) and lazily migrates them so future searches skip this phase.
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
try:
|
|
||||||
async for checkpoint_tuple in checkpointer.alist(None):
|
|
||||||
cfg = getattr(checkpoint_tuple, "config", {})
|
|
||||||
thread_id = cfg.get("configurable", {}).get("thread_id")
|
|
||||||
if not thread_id or thread_id in merged:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Skip sub-graph checkpoints (checkpoint_ns is non-empty for those)
|
|
||||||
if cfg.get("configurable", {}).get("checkpoint_ns", ""):
|
|
||||||
continue
|
|
||||||
|
|
||||||
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
|
|
||||||
# Strip LangGraph internal keys from the user-visible metadata dict
|
|
||||||
user_meta = {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
|
|
||||||
|
|
||||||
# Extract state values (title) from the checkpoint's channel_values
|
|
||||||
checkpoint_data = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
channel_values = checkpoint_data.get("channel_values", {})
|
|
||||||
ckpt_values = {}
|
|
||||||
if title := channel_values.get("title"):
|
|
||||||
ckpt_values["title"] = title
|
|
||||||
|
|
||||||
thread_resp = ThreadResponse(
|
|
||||||
thread_id=thread_id,
|
|
||||||
status=_derive_thread_status(checkpoint_tuple),
|
|
||||||
created_at=str(ckpt_meta.get("created_at", "")),
|
|
||||||
updated_at=str(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
|
|
||||||
metadata=user_meta,
|
|
||||||
values=ckpt_values,
|
|
||||||
)
|
|
||||||
merged[thread_id] = thread_resp
|
|
||||||
|
|
||||||
# Lazy migration — write to Store so the next search finds it there
|
|
||||||
if store is not None:
|
|
||||||
try:
|
|
||||||
await _store_upsert(store, thread_id, metadata=user_meta, values=ckpt_values or None)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Failed to migrate thread %s to store (non-fatal)", thread_id)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Checkpointer scan failed during thread search")
|
|
||||||
# Don't raise — return whatever was collected from Store + partial scan
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
# Phase 3: Filter → sort → paginate
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
results = list(merged.values())
|
|
||||||
|
|
||||||
if body.metadata:
|
|
||||||
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
|
|
||||||
|
|
||||||
if body.status:
|
|
||||||
results = [r for r in results if r.status == body.status]
|
|
||||||
|
|
||||||
results.sort(key=lambda r: r.updated_at, reverse=True)
|
|
||||||
return results[body.offset : body.offset + body.limit]
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{thread_id}", response_model=ThreadResponse)
|
@router.patch("/{thread_id}", response_model=ThreadResponse)
|
||||||
|
|||||||
@@ -323,6 +323,7 @@ async def start_run(
|
|||||||
event_store=event_store,
|
event_store=event_store,
|
||||||
run_events_config=run_events_config,
|
run_events_config=run_events_config,
|
||||||
follow_up_to_run_id=follow_up_to_run_id,
|
follow_up_to_run_id=follow_up_to_run_id,
|
||||||
|
thread_meta_repo=thread_meta_repo,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
record.task = task
|
record.task = task
|
||||||
|
|||||||
@@ -78,6 +78,37 @@ class ThreadMetaRepository:
|
|||||||
return True
|
return True
|
||||||
return row.owner_id == owner_id
|
return row.owner_id == owner_id
|
||||||
|
|
||||||
|
async def search(
    self,
    *,
    metadata: dict | None = None,
    status: str | None = None,
    limit: int = 100,
    offset: int = 0,
) -> list[dict]:
    """Search threads with optional metadata and status filters.

    Rows are ordered by ``updated_at`` descending and converted with
    ``self._row_to_dict``.

    Args:
        metadata: Key/value pairs that must all equal the corresponding
            entries of a row's ``metadata`` dict. ``None`` or an empty dict
            disables the filter.
        status: When truthy, only rows whose status equals this value.
        limit: Maximum number of rows to return.
        offset: Number of matching rows to skip before the returned page.

    Returns:
        One page of matching rows as dicts.
    """
    stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
    if status:
        stmt = stmt.where(ThreadMetaRow.status == status)

    if not metadata:
        # No Python-side filter follows, so the database can paginate.
        stmt = stmt.limit(limit).offset(offset)

    async with self._sf() as session:
        result = await session.execute(stmt)
        rows = [self._row_to_dict(r) for r in result.scalars()]

    if not metadata:
        return rows

    # The metadata filter runs in Python, so pagination must be applied
    # AFTER it. Paginating in SQL first would drop matching rows that sit
    # beyond the unfiltered window and make `offset` count non-matching rows.
    # NOTE(review): this loads every status-matching row; consider pushing
    # the metadata predicate into SQL if the table grows large.
    matching = [
        r
        for r in rows
        # `or {}` tolerates rows whose metadata is stored as None.
        if all((r.get("metadata") or {}).get(k) == v for k, v in metadata.items())
    ]
    start = max(offset, 0)
    return matching[start : start + max(limit, 0)]
|
||||||
|
|
||||||
|
async def update_display_name(self, thread_id: str, display_name: str) -> None:
    """Set the display name (title) of a thread.

    Issues an UPDATE against ``ThreadMetaRow`` for the row whose
    ``thread_id`` matches, writing the new ``display_name`` and setting
    ``updated_at`` to the current UTC time, then commits the session.

    Args:
        thread_id: Identifier of the thread to update.
        display_name: New title to store.
    """
    set_title = (
        update(ThreadMetaRow)
        .where(ThreadMetaRow.thread_id == thread_id)
        .values(display_name=display_name, updated_at=datetime.now(UTC))
    )
    async with self._sf() as db:
        await db.execute(set_title)
        await db.commit()
|
||||||
|
|
||||||
async def update_status(self, thread_id: str, status: str) -> None:
|
async def update_status(self, thread_id: str, status: str) -> None:
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ async def run_agent(
|
|||||||
event_store: Any | None = None,
|
event_store: Any | None = None,
|
||||||
run_events_config: Any | None = None,
|
run_events_config: Any | None = None,
|
||||||
follow_up_to_run_id: str | None = None,
|
follow_up_to_run_id: str | None = None,
|
||||||
|
thread_meta_repo: Any | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute an agent in the background, publishing events to *bridge*."""
|
"""Execute an agent in the background, publishing events to *bridge*."""
|
||||||
|
|
||||||
@@ -262,6 +263,19 @@ async def run_agent(
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
||||||
|
|
||||||
|
# Sync title from checkpoint to threads_meta.display_name
|
||||||
|
if thread_meta_repo is not None and checkpointer is not None:
|
||||||
|
try:
|
||||||
|
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
|
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
|
||||||
|
if ckpt_tuple is not None:
|
||||||
|
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
||||||
|
title = ckpt.get("channel_values", {}).get("title")
|
||||||
|
if title:
|
||||||
|
await thread_meta_repo.update_display_name(thread_id, title)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
|
||||||
|
|
||||||
await bridge.publish_end(run_id)
|
await bridge.publish_end(run_id)
|
||||||
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user