refactor(gateway): route all thread metadata access through ThreadMetaStore

Following the rename/delete bug fix in PR1, migrate the remaining direct
LangGraph Store reads/writes in the threads router and services to the
ThreadMetaStore abstraction so that the sqlite and memory backends behave
identically and the legacy dual-write paths can be removed.

Migrated endpoints (threads.py):
- create_thread: idempotency check + write now use thread_meta_repo.get/create
  instead of dual-writing the LangGraph Store and the SQL row.
- get_thread: reads from thread_meta_repo.get; the checkpoint-only fallback
  for legacy threads is preserved.
- patch_thread: replaced _store_get/_store_put with thread_meta_repo.update_metadata.
- delete_thread_data: dropped the legacy store.adelete; thread_meta_repo.delete
  already covers it.

Removed dead code (services.py):
- _upsert_thread_in_store — redundant with the immediately following
  thread_meta_repo.create() call.
- _sync_thread_title_after_run — worker.py's finally block already syncs
  the title via thread_meta_repo.update_display_name() after each run.

Removed dead code (threads.py):
- _store_get / _store_put / _store_upsert helpers (no remaining callers).
- THREADS_NS constant.
- get_store import (router no longer touches the LangGraph Store directly).

New abstract method:
- ThreadMetaStore.update_metadata(thread_id, metadata) merges metadata into
  the thread's metadata field. Implemented in both ThreadMetaRepository (SQL,
  read-modify-write inside one session) and MemoryThreadMetaStore. Three new
  unit tests cover merge / empty / nonexistent behaviour.

Net change: -90 lines (116 additions, 206 deletions). Full test suite: 1693 passed, 14 skipped.
Verified end-to-end with curl in gateway mode against sqlite backend
(create / patch / get / rename / search / delete).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-07 10:56:03 +08:00
parent 6f155d3b4b
commit 439c10d6f2
6 changed files with 116 additions and 206 deletions
+48 -129
View File
@@ -20,18 +20,11 @@ from typing import Any
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.deps import get_checkpointer, get_store from app.gateway.deps import get_checkpointer
from app.gateway.utils import sanitize_log_param from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values from deerflow.runtime import serialize_channel_values
# ---------------------------------------------------------------------------
# Store namespace
# ---------------------------------------------------------------------------
THREADS_NS: tuple[str, ...] = ("threads",)
"""Namespace used by the Store for thread metadata records."""
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"]) router = APIRouter(prefix="/api/threads", tags=["threads"])
@@ -147,51 +140,6 @@ def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDel
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}") return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
async def _store_get(store, thread_id: str) -> dict | None:
    """Look up the Store item keyed by ``thread_id`` under ``THREADS_NS``.

    Returns the item's ``value``, or ``None`` when the Store has no item
    for that key.
    """
    entry = await store.aget(THREADS_NS, thread_id)
    if entry is None:
        return None
    return entry.value
async def _store_put(store, record: dict) -> None:
    """Persist ``record`` in the Store under ``THREADS_NS``.

    The item key is taken from ``record["thread_id"]``, so the record must
    carry that field.
    """
    key = record["thread_id"]
    await store.aput(THREADS_NS, key, record)
async def _store_upsert(store, thread_id: str, *, metadata: dict | None = None, values: dict | None = None) -> None:
    """Insert a thread record into the Store, or refresh the existing one.

    A missing record is created with ``status="idle"`` and identical
    ``created_at`` / ``updated_at`` timestamps. An existing record keeps all
    of its fields; only ``updated_at`` is bumped, and non-empty ``metadata``
    / ``values`` arguments are merged into the corresponding sub-dicts.

    ``values`` is the agent-state snapshot exposed to the frontend
    (currently just ``{"title": "..."}``).
    """
    timestamp = time.time()
    current = await _store_get(store, thread_id)

    if current is not None:
        # Update path: shallow-copy the record and merge only what was supplied.
        refreshed = dict(current)
        refreshed["updated_at"] = timestamp
        for field, patch in (("metadata", metadata), ("values", values)):
            if patch:
                refreshed.setdefault(field, {}).update(patch)
        await _store_put(store, refreshed)
        return

    # Create path: write a brand-new idle record.
    fresh = {
        "thread_id": thread_id,
        "status": "idle",
        "created_at": timestamp,
        "updated_at": timestamp,
        "metadata": metadata or {},
        "values": values or {},
    }
    await _store_put(store, fresh)
def _derive_thread_status(checkpoint_tuple) -> str: def _derive_thread_status(checkpoint_tuple) -> str:
"""Derive thread status from checkpoint metadata.""" """Derive thread status from checkpoint metadata."""
if checkpoint_tuple is None: if checkpoint_tuple is None:
@@ -221,22 +169,14 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
"""Delete local persisted filesystem data for a thread. """Delete local persisted filesystem data for a thread.
Cleans DeerFlow-managed thread directories, removes checkpoint data, Cleans DeerFlow-managed thread directories, removes checkpoint data,
removes the thread record from the Store, and removes the thread_meta and removes the thread_meta row from the configured ThreadMetaStore
row from the configured ThreadMetaStore (sqlite or memory). (sqlite or memory).
""" """
from app.gateway.deps import get_thread_meta_repo from app.gateway.deps import get_thread_meta_repo
# Clean local filesystem # Clean local filesystem
response = _delete_thread_data(thread_id) response = _delete_thread_data(thread_id)
# Remove from Store (best-effort) — legacy in-memory thread record
store = get_store(request)
if store is not None:
try:
await store.adelete(THREADS_NS, thread_id)
except Exception:
logger.debug("Could not delete store record for thread %s (not critical)", sanitize_log_param(thread_id))
# Remove checkpoints (best-effort) # Remove checkpoints (best-effort)
checkpointer = getattr(request.app.state, "checkpointer", None) checkpointer = getattr(request.app.state, "checkpointer", None)
if checkpointer is not None: if checkpointer is not None:
@@ -261,43 +201,38 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse: async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
"""Create a new thread. """Create a new thread.
The thread record is written to the Store (for fast listing) and an Writes a thread_meta record (so the thread appears in /threads/search)
empty checkpoint is written to the checkpointer (for state reads). and an empty checkpoint (so state endpoints work immediately).
Idempotent: returns the existing record when ``thread_id`` already exists. Idempotent: returns the existing record when ``thread_id`` already exists.
""" """
store = get_store(request) from app.gateway.deps import get_thread_meta_repo
checkpointer = get_checkpointer(request) checkpointer = get_checkpointer(request)
thread_meta_repo = get_thread_meta_repo(request)
thread_id = body.thread_id or str(uuid.uuid4()) thread_id = body.thread_id or str(uuid.uuid4())
now = time.time() now = time.time()
# Idempotency: return existing record from Store when already present # Idempotency: return existing record when already present
if store is not None: existing_record = await thread_meta_repo.get(thread_id)
existing_record = await _store_get(store, thread_id) if existing_record is not None:
if existing_record is not None: return ThreadResponse(
return ThreadResponse( thread_id=thread_id,
thread_id=thread_id, status=existing_record.get("status", "idle"),
status=existing_record.get("status", "idle"), created_at=str(existing_record.get("created_at", "")),
created_at=str(existing_record.get("created_at", "")), updated_at=str(existing_record.get("updated_at", "")),
updated_at=str(existing_record.get("updated_at", "")), metadata=existing_record.get("metadata", {}),
metadata=existing_record.get("metadata", {}), )
)
# Write thread record to Store # Write thread_meta so the thread appears in /threads/search immediately
if store is not None: try:
try: await thread_meta_repo.create(
await _store_put( thread_id,
store, assistant_id=getattr(body, "assistant_id", None),
{ metadata=body.metadata,
"thread_id": thread_id, )
"status": "idle", except Exception:
"created_at": now, logger.exception("Failed to write thread_meta for %s", sanitize_log_param(thread_id))
"updated_at": now, raise HTTPException(status_code=500, detail="Failed to create thread")
"metadata": body.metadata,
},
)
except Exception:
logger.exception("Failed to write thread %s to store", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread")
# Write an empty checkpoint so state endpoints work immediately # Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
@@ -317,19 +252,6 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id)) logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread") raise HTTPException(status_code=500, detail="Failed to create thread")
# Write thread_meta so the thread appears in /threads/search immediately
from app.gateway.deps import get_thread_meta_repo
thread_meta_repo = get_thread_meta_repo(request)
try:
await thread_meta_repo.create(
thread_id,
assistant_id=getattr(body, "assistant_id", None),
metadata=body.metadata,
)
except Exception:
logger.debug("Failed to upsert thread_meta on create for %s (non-fatal)", sanitize_log_param(thread_id))
logger.info("Thread created: %s", sanitize_log_param(thread_id)) logger.info("Thread created: %s", sanitize_log_param(thread_id))
return ThreadResponse( return ThreadResponse(
thread_id=thread_id, thread_id=thread_id,
@@ -373,31 +295,27 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
@router.patch("/{thread_id}", response_model=ThreadResponse) @router.patch("/{thread_id}", response_model=ThreadResponse)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse: async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record.""" """Merge metadata into a thread record."""
store = get_store(request) from app.gateway.deps import get_thread_meta_repo
if store is None:
raise HTTPException(status_code=503, detail="Store not available")
record = await _store_get(store, thread_id) thread_meta_repo = get_thread_meta_repo(request)
record = await thread_meta_repo.get(thread_id)
if record is None: if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
now = time.time()
updated = dict(record)
updated.setdefault("metadata", {}).update(body.metadata)
updated["updated_at"] = now
try: try:
await _store_put(store, updated) await thread_meta_repo.update_metadata(thread_id, body.metadata)
except Exception: except Exception:
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id)) logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread") raise HTTPException(status_code=500, detail="Failed to update thread")
# Re-read to get the merged metadata + refreshed updated_at
record = await thread_meta_repo.get(thread_id) or record
return ThreadResponse( return ThreadResponse(
thread_id=thread_id, thread_id=thread_id,
status=updated.get("status", "idle"), status=record.get("status", "idle"),
created_at=str(updated.get("created_at", "")), created_at=str(record.get("created_at", "")),
updated_at=str(now), updated_at=str(record.get("updated_at", "")),
metadata=updated.get("metadata", {}), metadata=record.get("metadata", {}),
) )
@@ -405,16 +323,16 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
async def get_thread(thread_id: str, request: Request) -> ThreadResponse: async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
"""Get thread info. """Get thread info.
Reads metadata from the Store and derives the accurate execution Reads metadata from the ThreadMetaStore and derives the accurate
status from the checkpointer. Falls back to the checkpointer alone execution status from the checkpointer. Falls back to the checkpointer
for threads that pre-date Store adoption (backward compat). alone for threads that pre-date ThreadMetaStore adoption (backward compat).
""" """
store = get_store(request) from app.gateway.deps import get_thread_meta_repo
thread_meta_repo = get_thread_meta_repo(request)
checkpointer = get_checkpointer(request) checkpointer = get_checkpointer(request)
record: dict | None = None record: dict | None = await thread_meta_repo.get(thread_id)
if store is not None:
record = await _store_get(store, thread_id)
# Derive accurate status from the checkpointer # Derive accurate status from the checkpointer
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
@@ -427,8 +345,9 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
if record is None and checkpoint_tuple is None: if record is None and checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# If the thread exists in the checkpointer but not the store (e.g. legacy # If the thread exists in the checkpointer but not in thread_meta (e.g.
# data), synthesize a minimal store record from the checkpoint metadata. # legacy data created before thread_meta adoption), synthesize a minimal
# record from the checkpoint metadata.
if record is None and checkpoint_tuple is not None: if record is None and checkpoint_tuple is not None:
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {} ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
record = { record = {
+6 -77
View File
@@ -12,7 +12,6 @@ import dataclasses
import json import json
import logging import logging
import re import re
import time
from typing import Any from typing import Any
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
@@ -173,71 +172,6 @@ def build_run_config(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None) -> None:
    """Create or refresh the thread record in the Store (best-effort).

    Called from :func:`start_run` so that threads created via the stateless
    ``/runs/stream`` endpoint (which never calls ``POST /threads``) still
    appear in ``/threads/search`` results.

    Any exception raised by the upsert is logged and swallowed; this
    function never propagates a Store failure to the caller.
    """
    # Deferred import to avoid circular import with the threads router module.
    from app.gateway.routers.threads import _store_upsert

    try:
        await _store_upsert(store, thread_id, metadata=metadata)
    except Exception:
        # Non-fatal by design, but keep the traceback so the root cause of a
        # failing Store write is still diagnosable from the logs.
        logger.warning(
            "Failed to upsert thread %s in store (non-fatal)",
            sanitize_log_param(thread_id),
            exc_info=True,
        )
async def _sync_thread_title_after_run(
    run_task: asyncio.Task,
    thread_id: str,
    checkpointer: Any,
    store: Any,
) -> None:
    """Wait for *run_task* to finish, then persist the generated title to the Store.

    TitleMiddleware writes the generated title to the LangGraph agent state
    (checkpointer) but the Gateway's Store record is not updated automatically.
    This coroutine closes that gap by reading the final checkpoint after the
    run completes and syncing ``values.title`` into the Store record so that
    subsequent ``/threads/search`` responses include the correct title.

    Nothing is written when there is no checkpoint, the checkpoint has no
    (truthy) ``title`` channel value, or the Store has no record for the
    thread.

    Runs as a fire-and-forget :func:`asyncio.create_task`; failures while
    reading or writing are logged at DEBUG level and never propagate.
    """
    # Wait for the background run task to complete (any outcome).
    # asyncio.wait does not propagate task exceptions — it just returns
    # when the task is done, cancelled, or failed.
    await asyncio.wait({run_task})

    # Deferred import to avoid circular import with the threads router module.
    from app.gateway.routers.threads import _store_get, _store_put

    # thread_id may originate from the request path; sanitize it once so both
    # log statements below match the logging convention used in this module.
    safe_thread_id = sanitize_log_param(thread_id)

    try:
        ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
        ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
        if ckpt_tuple is None:
            return

        channel_values = ckpt_tuple.checkpoint.get("channel_values", {})
        title = channel_values.get("title")
        if not title:
            return

        existing = await _store_get(store, thread_id)
        if existing is None:
            return

        updated = dict(existing)
        updated.setdefault("values", {})["title"] = title
        updated["updated_at"] = time.time()
        await _store_put(store, updated)
        logger.debug("Synced title %r for thread %s", title, safe_thread_id)
    except Exception:
        logger.debug("Failed to sync title for thread %s (non-fatal)", safe_thread_id, exc_info=True)
async def start_run( async def start_run(
body: Any, body: Any,
thread_id: str, thread_id: str,
@@ -291,12 +225,9 @@ async def start_run(
except UnsupportedStrategyError as exc: except UnsupportedStrategyError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc raise HTTPException(status_code=501, detail=str(exc)) from exc
# Ensure the thread is visible in /threads/search, even for threads that # Upsert thread metadata so the thread appears in /threads/search,
# were never explicitly created via POST /threads (e.g. stateless runs). # even for threads that were never explicitly created via POST /threads
if run_ctx.store is not None: # (e.g. stateless runs).
await _upsert_thread_in_store(run_ctx.store, thread_id, body.metadata)
# Upsert thread metadata so the thread appears in /threads/search
try: try:
existing = await run_ctx.thread_meta_repo.get(thread_id) existing = await run_ctx.thread_meta_repo.get(thread_id)
if existing is None: if existing is None:
@@ -353,11 +284,9 @@ async def start_run(
) )
record.task = task record.task = task
# After the run completes, sync the title generated by TitleMiddleware from # Title sync is handled by worker.py's finally block which reads the
# the checkpointer into the Store record so that /threads/search returns the # title from the checkpoint and calls thread_meta_repo.update_display_name
# correct title instead of an empty values dict. # after the run completes.
if run_ctx.store is not None:
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, run_ctx.checkpointer, run_ctx.store))
return record return record
@@ -46,6 +46,15 @@ class ThreadMetaStore(abc.ABC):
async def update_status(self, thread_id: str, status: str) -> None: async def update_status(self, thread_id: str, status: str) -> None:
pass pass
@abc.abstractmethod
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
    """Shallow-merge ``metadata`` into the stored metadata of ``thread_id``.

    Contract for implementations: keys present in ``metadata`` replace the
    stored values, keys not mentioned in ``metadata`` are left untouched,
    and calling this for an unknown ``thread_id`` does nothing.
    """
    pass
@abc.abstractmethod @abc.abstractmethod
async def delete(self, thread_id: str) -> None: async def delete(self, thread_id: str) -> None:
pass pass
@@ -89,6 +89,18 @@ class MemoryThreadMetaStore(ThreadMetaStore):
record["updated_at"] = time.time() record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record) await self._store.aput(THREADS_NS, thread_id, record)
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
    """Shallow-merge ``metadata`` into the thread's Store record.

    Does nothing when the Store has no item for ``thread_id``. Otherwise
    the record is copied, its ``metadata`` dict is rebuilt with the new
    keys layered on top, ``updated_at`` is refreshed, and the copy is
    written back under the same key.
    """
    stored = await self._store.aget(THREADS_NS, thread_id)
    if stored is None:
        return

    refreshed = dict(stored.value)
    # Build a new dict so the previously stored metadata object is not mutated;
    # a missing or None "metadata" field is treated as empty.
    combined = dict(refreshed.get("metadata") or {})
    combined.update(metadata)
    refreshed["metadata"] = combined
    refreshed["updated_at"] = time.time()

    await self._store.aput(THREADS_NS, thread_id, refreshed)
async def delete(self, thread_id: str) -> None: async def delete(self, thread_id: str) -> None:
await self._store.adelete(THREADS_NS, thread_id) await self._store.adelete(THREADS_NS, thread_id)
@@ -116,6 +116,22 @@ class ThreadMetaRepository(ThreadMetaStore):
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC))) await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
await session.commit() await session.commit()
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
    """Merge ``metadata`` into the row's ``metadata_json`` column.

    Keys in ``metadata`` overwrite existing keys; other keys are preserved.
    ``updated_at`` is refreshed on every successful merge. No-op if the row
    does not exist.

    The merge is a read-modify-write, so the row is loaded with a
    ``FOR UPDATE`` lock request. On backends with row-level locking this
    stops two concurrent merges from reading the same stale value and
    overwriting each other's keys. Dialects that have no ``FOR UPDATE``
    (e.g. SQLite) emit no lock clause, so behaviour there is unchanged.
    """
    async with self._sf() as session:
        row = await session.get(ThreadMetaRow, thread_id, with_for_update=True)
        if row is None:
            return
        merged = dict(row.metadata_json or {})
        merged.update(metadata)
        # Assign a new dict instead of mutating the loaded one in place.
        # NOTE(review): presumably metadata_json is a plain JSON column without
        # mutation tracking, in which case reassignment is what marks it dirty.
        row.metadata_json = merged
        row.updated_at = datetime.now(UTC)
        await session.commit()
async def delete(self, thread_id: str) -> None: async def delete(self, thread_id: str) -> None:
async with self._sf() as session: async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id) row = await session.get(ThreadMetaRow, thread_id)
+25
View File
@@ -130,3 +130,28 @@ class TestThreadMetaRepository:
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
await repo.delete("nonexistent") # should not raise await repo.delete("nonexistent") # should not raise
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_update_metadata_merges(self, tmp_path):
    """update_metadata keeps untouched keys, overwrites shared keys, adds new keys."""
    repo = await _make_repo(tmp_path)
    try:
        await repo.create("t1", metadata={"a": 1, "b": 2})
        await repo.update_metadata("t1", {"b": 99, "c": 3})
        record = await repo.get("t1")
        # Existing key preserved, overlapping key overwritten, new key added
        assert record["metadata"] == {"a": 1, "b": 99, "c": 3}
    finally:
        # Run cleanup even when an assertion fails so a failing test does not
        # skip the teardown that the passing path performs.
        await _cleanup()
@pytest.mark.anyio
async def test_update_metadata_on_empty(self, tmp_path):
    """update_metadata on a thread created without metadata stores exactly the new keys."""
    repo = await _make_repo(tmp_path)
    try:
        await repo.create("t1")
        await repo.update_metadata("t1", {"k": "v"})
        record = await repo.get("t1")
        assert record["metadata"] == {"k": "v"}
    finally:
        # Run cleanup even when an assertion fails so a failing test does not
        # skip the teardown that the passing path performs.
        await _cleanup()
@pytest.mark.anyio
async def test_update_metadata_nonexistent_is_noop(self, tmp_path):
    """update_metadata on an unknown thread neither raises nor creates a row."""
    repo = await _make_repo(tmp_path)
    try:
        await repo.update_metadata("nonexistent", {"k": "v"})  # should not raise
        # "No-op" also means the call must not have inserted a record.
        assert await repo.get("nonexistent") is None
    finally:
        # Run cleanup even when an assertion fails so a failing test does not
        # skip the teardown that the passing path performs.
        await _cleanup()