refactor(runtime): introduce RunContext to reduce run_agent parameter bloat

Extract checkpointer, store, event_store, run_events_config, thread_meta_repo,
and follow_up_to_run_id into a frozen RunContext dataclass. Add get_run_context()
in deps.py to build the base context from app.state singletons. start_run() uses
dataclasses.replace() to enrich per-run fields before passing ctx to run_agent.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-06 10:59:47 +08:00
parent 8746a2bcd9
commit eba6810a44
5 changed files with 67 additions and 34 deletions
+20 -1
View File
@@ -14,7 +14,7 @@ from contextlib import AsyncExitStack, asynccontextmanager
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from deerflow.runtime import RunManager from deerflow.runtime import RunContext, RunManager
@asynccontextmanager @asynccontextmanager
@@ -109,6 +109,25 @@ def get_thread_meta_repo(request: Request):
return getattr(request.app.state, "thread_meta_repo", None) return getattr(request.app.state, "thread_meta_repo", None)
def get_run_context(request: Request) -> RunContext:
    """Assemble the base :class:`RunContext` for the current request.

    Each dependency is resolved through this module's getter functions,
    and ``run_events_config`` is read from the ``run_events`` attribute of
    the application config (``None`` when that attribute is absent).

    The returned context leaves ``follow_up_to_run_id`` at its default.
    Callers that need per-run fields should derive a new context with
    ``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before handing
    it to :func:`run_agent`.
    """
    from deerflow.config import get_app_config

    # Resolve dependencies one at a time, in the same order in which the
    # keyword arguments were originally evaluated.
    checkpointer = get_checkpointer(request)
    store = get_store(request)
    event_store = get_run_event_store(request)
    run_events_config = getattr(get_app_config(), "run_events", None)
    thread_meta_repo = get_thread_meta_repo(request)

    return RunContext(
        checkpointer=checkpointer,
        store=store,
        event_store=event_store,
        run_events_config=run_events_config,
        thread_meta_repo=thread_meta_repo,
    )
async def get_current_user(request: Request) -> str | None: async def get_current_user(request: Request) -> str | None:
"""Extract user identity from request. """Extract user identity from request.
+16 -25
View File
@@ -8,6 +8,7 @@ frames, and consuming stream bridge events. Router modules
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses
import json import json
import logging import logging
import re import re
@@ -17,7 +18,7 @@ from typing import Any
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge, get_thread_meta_repo from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.routers.threads import _sanitize_log_param from app.gateway.routers.threads import _sanitize_log_param
from deerflow.runtime import ( from deerflow.runtime import (
END_SENTINEL, END_SENTINEL,
@@ -256,14 +257,7 @@ async def start_run(
""" """
bridge = get_stream_bridge(request) bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request) run_mgr = get_run_manager(request)
checkpointer = get_checkpointer(request) run_ctx = get_run_context(request)
store = get_store(request)
event_store = get_run_event_store(request)
# Get run_events config for journal
from deerflow.config import get_app_config
run_events_config = getattr(get_app_config(), "run_events", None)
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_ disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
@@ -278,6 +272,10 @@ async def start_run(
except Exception: except Exception:
pass # Don't block run creation pass # Don't block run creation
# Enrich base context with per-run field
if follow_up_to_run_id:
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
try: try:
record = await run_mgr.create_or_reject( record = await run_mgr.create_or_reject(
thread_id, thread_id,
@@ -295,23 +293,21 @@ async def start_run(
# Ensure the thread is visible in /threads/search, even for threads that # Ensure the thread is visible in /threads/search, even for threads that
# were never explicitly created via POST /threads (e.g. stateless runs). # were never explicitly created via POST /threads (e.g. stateless runs).
store = get_store(request) if run_ctx.store is not None:
if store is not None: await _upsert_thread_in_store(run_ctx.store, thread_id, body.metadata)
await _upsert_thread_in_store(store, thread_id, body.metadata)
# Upsert thread metadata in the SQL-backed threads_meta table # Upsert thread metadata in the SQL-backed threads_meta table
thread_meta_repo = get_thread_meta_repo(request) if run_ctx.thread_meta_repo is not None:
if thread_meta_repo is not None:
try: try:
existing = await thread_meta_repo.get(thread_id) existing = await run_ctx.thread_meta_repo.get(thread_id)
if existing is None: if existing is None:
await thread_meta_repo.create( await run_ctx.thread_meta_repo.create(
thread_id, thread_id,
assistant_id=body.assistant_id, assistant_id=body.assistant_id,
metadata=body.metadata, metadata=body.metadata,
) )
else: else:
await thread_meta_repo.update_status(thread_id, "running") await run_ctx.thread_meta_repo.update_status(thread_id, "running")
except Exception: except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", _sanitize_log_param(thread_id)) logger.warning("Failed to upsert thread_meta for %s (non-fatal)", _sanitize_log_param(thread_id))
@@ -346,8 +342,7 @@ async def start_run(
bridge, bridge,
run_mgr, run_mgr,
record, record,
checkpointer=checkpointer, ctx=run_ctx,
store=store,
agent_factory=agent_factory, agent_factory=agent_factory,
graph_input=graph_input, graph_input=graph_input,
config=config, config=config,
@@ -355,10 +350,6 @@ async def start_run(
stream_subgraphs=body.stream_subgraphs, stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before, interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after, interrupt_after=body.interrupt_after,
event_store=event_store,
run_events_config=run_events_config,
follow_up_to_run_id=follow_up_to_run_id,
thread_meta_repo=thread_meta_repo,
) )
) )
record.task = task record.task = task
@@ -366,8 +357,8 @@ async def start_run(
# After the run completes, sync the title generated by TitleMiddleware from # After the run completes, sync the title generated by TitleMiddleware from
# the checkpointer into the Store record so that /threads/search returns the # the checkpointer into the Store record so that /threads/search returns the
# correct title instead of an empty values dict. # correct title instead of an empty values dict.
if store is not None: if run_ctx.store is not None:
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store)) asyncio.create_task(_sync_thread_title_after_run(task, thread_id, run_ctx.checkpointer, run_ctx.store))
return record return record
@@ -5,7 +5,7 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and
directly from ``deerflow.runtime``. directly from ``deerflow.runtime``.
""" """
from .runs import ConflictError, DisconnectMode, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
from .store import get_store, make_store, reset_store, store_context from .store import get_store, make_store, reset_store, store_context
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
@@ -14,6 +14,7 @@ __all__ = [
# runs # runs
"ConflictError", "ConflictError",
"DisconnectMode", "DisconnectMode",
"RunContext",
"RunManager", "RunManager",
"RunRecord", "RunRecord",
"RunStatus", "RunStatus",
@@ -2,11 +2,12 @@
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
from .schemas import DisconnectMode, RunStatus from .schemas import DisconnectMode, RunStatus
from .worker import run_agent from .worker import RunContext, run_agent
__all__ = [ __all__ = [
"ConflictError", "ConflictError",
"DisconnectMode", "DisconnectMode",
"RunContext",
"RunManager", "RunManager",
"RunRecord", "RunRecord",
"RunStatus", "RunStatus",
@@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -34,13 +35,29 @@ logger = logging.getLogger(__name__)
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"} _VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
@dataclass(frozen=True)
class RunContext:
    """Bundle of dependencies handed to ``run_agent`` for one agent run.

    Passing a single immutable object replaces a growing list of keyword
    arguments. Instances are frozen, so a variant with different field
    values has to be created with ``dataclasses.replace`` rather than by
    assigning to attributes.
    """

    # Required: the only field without a default.
    checkpointer: Any
    # Optional dependencies; each defaults to ``None`` when not supplied.
    store: Any | None = field(default=None)
    event_store: Any | None = field(default=None)
    run_events_config: Any | None = field(default=None)
    thread_meta_repo: Any | None = field(default=None)
    # Per-run value rather than a shared dependency; ``None`` unless set.
    follow_up_to_run_id: str | None = field(default=None)
async def run_agent( async def run_agent(
bridge: StreamBridge, bridge: StreamBridge,
run_manager: RunManager, run_manager: RunManager,
record: RunRecord, record: RunRecord,
*, *,
checkpointer: Any, ctx: RunContext,
store: Any | None = None,
agent_factory: Any, agent_factory: Any,
graph_input: dict, graph_input: dict,
config: dict, config: dict,
@@ -48,13 +65,17 @@ async def run_agent(
stream_subgraphs: bool = False, stream_subgraphs: bool = False,
interrupt_before: list[str] | Literal["*"] | None = None, interrupt_before: list[str] | Literal["*"] | None = None,
interrupt_after: list[str] | Literal["*"] | None = None, interrupt_after: list[str] | Literal["*"] | None = None,
event_store: Any | None = None,
run_events_config: Any | None = None,
follow_up_to_run_id: str | None = None,
thread_meta_repo: Any | None = None,
) -> None: ) -> None:
"""Execute an agent in the background, publishing events to *bridge*.""" """Execute an agent in the background, publishing events to *bridge*."""
# Unpack infrastructure dependencies from RunContext.
checkpointer = ctx.checkpointer
store = ctx.store
event_store = ctx.event_store
run_events_config = ctx.run_events_config
thread_meta_repo = ctx.thread_meta_repo
follow_up_to_run_id = ctx.follow_up_to_run_id
run_id = record.run_id run_id = record.run_id
thread_id = record.thread_id thread_id = record.thread_id
requested_modes: set[str] = set(stream_modes or ["values"]) requested_modes: set[str] = set(stream_modes or ["values"])