refactor(runtime): introduce RunContext to reduce run_agent parameter bloat

Extract checkpointer, store, event_store, run_events_config, thread_meta_repo,
and follow_up_to_run_id into a frozen RunContext dataclass. Add get_run_context()
in deps.py to build the base context from app.state singletons. start_run() uses
dataclasses.replace() to enrich per-run fields before passing ctx to run_agent.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-06 10:59:47 +08:00
parent 8746a2bcd9
commit eba6810a44
5 changed files with 67 additions and 34 deletions
+20 -1
View File
@@ -14,7 +14,7 @@ from contextlib import AsyncExitStack, asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from deerflow.runtime import RunManager
from deerflow.runtime import RunContext, RunManager
@asynccontextmanager
@@ -109,6 +109,25 @@ def get_thread_meta_repo(request: Request):
return getattr(request.app.state, "thread_meta_repo", None)
def get_run_context(request: Request) -> RunContext:
"""Build a :class:`RunContext` from ``app.state`` singletons.
Returns a *base* context with infrastructure dependencies. Callers that
need per-run fields (e.g. ``follow_up_to_run_id``) should use
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
to :func:`run_agent`.
"""
from deerflow.config import get_app_config
return RunContext(
checkpointer=get_checkpointer(request),
store=get_store(request),
event_store=get_run_event_store(request),
run_events_config=getattr(get_app_config(), "run_events", None),
thread_meta_repo=get_thread_meta_repo(request),
)
async def get_current_user(request: Request) -> str | None:
"""Extract user identity from request.
+16 -25
View File
@@ -8,6 +8,7 @@ frames, and consuming stream bridge events. Router modules
from __future__ import annotations
import asyncio
import dataclasses
import json
import logging
import re
@@ -17,7 +18,7 @@ from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge, get_thread_meta_repo
from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.routers.threads import _sanitize_log_param
from deerflow.runtime import (
END_SENTINEL,
@@ -256,14 +257,7 @@ async def start_run(
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
checkpointer = get_checkpointer(request)
store = get_store(request)
event_store = get_run_event_store(request)
# Get run_events config for journal
from deerflow.config import get_app_config
run_events_config = getattr(get_app_config(), "run_events", None)
run_ctx = get_run_context(request)
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
@@ -278,6 +272,10 @@ async def start_run(
except Exception:
pass # Don't block run creation
# Enrich base context with per-run field
if follow_up_to_run_id:
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
try:
record = await run_mgr.create_or_reject(
thread_id,
@@ -295,23 +293,21 @@ async def start_run(
# Ensure the thread is visible in /threads/search, even for threads that
# were never explicitly created via POST /threads (e.g. stateless runs).
store = get_store(request)
if store is not None:
await _upsert_thread_in_store(store, thread_id, body.metadata)
if run_ctx.store is not None:
await _upsert_thread_in_store(run_ctx.store, thread_id, body.metadata)
# Upsert thread metadata in the SQL-backed threads_meta table
thread_meta_repo = get_thread_meta_repo(request)
if thread_meta_repo is not None:
if run_ctx.thread_meta_repo is not None:
try:
existing = await thread_meta_repo.get(thread_id)
existing = await run_ctx.thread_meta_repo.get(thread_id)
if existing is None:
await thread_meta_repo.create(
await run_ctx.thread_meta_repo.create(
thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
)
else:
await thread_meta_repo.update_status(thread_id, "running")
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", _sanitize_log_param(thread_id))
@@ -346,8 +342,7 @@ async def start_run(
bridge,
run_mgr,
record,
checkpointer=checkpointer,
store=store,
ctx=run_ctx,
agent_factory=agent_factory,
graph_input=graph_input,
config=config,
@@ -355,10 +350,6 @@ async def start_run(
stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after,
event_store=event_store,
run_events_config=run_events_config,
follow_up_to_run_id=follow_up_to_run_id,
thread_meta_repo=thread_meta_repo,
)
)
record.task = task
@@ -366,8 +357,8 @@ async def start_run(
# After the run completes, sync the title generated by TitleMiddleware from
# the checkpointer into the Store record so that /threads/search returns the
# correct title instead of an empty values dict.
if store is not None:
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
if run_ctx.store is not None:
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, run_ctx.checkpointer, run_ctx.store))
return record