mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 07:56:48 +00:00
refactor(runtime): introduce RunContext to reduce run_agent parameter bloat
Extract checkpointer, store, event_store, run_events_config, thread_meta_repo, and follow_up_to_run_id into a frozen RunContext dataclass. Add get_run_context() in deps.py to build the base context from app.state singletons. start_run() uses dataclasses.replace() to enrich per-run fields before passing ctx to run_agent. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -14,7 +14,7 @@ from contextlib import AsyncExitStack, asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
|
||||
from deerflow.runtime import RunManager
|
||||
from deerflow.runtime import RunContext, RunManager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -109,6 +109,25 @@ def get_thread_meta_repo(request: Request):
|
||||
return getattr(request.app.state, "thread_meta_repo", None)
|
||||
|
||||
|
||||
def get_run_context(request: Request) -> RunContext:
|
||||
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
||||
|
||||
Returns a *base* context with infrastructure dependencies. Callers that
|
||||
need per-run fields (e.g. ``follow_up_to_run_id``) should use
|
||||
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
|
||||
to :func:`run_agent`.
|
||||
"""
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
return RunContext(
|
||||
checkpointer=get_checkpointer(request),
|
||||
store=get_store(request),
|
||||
event_store=get_run_event_store(request),
|
||||
run_events_config=getattr(get_app_config(), "run_events", None),
|
||||
thread_meta_repo=get_thread_meta_repo(request),
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> str | None:
|
||||
"""Extract user identity from request.
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ frames, and consuming stream bridge events. Router modules
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@@ -17,7 +18,7 @@ from typing import Any
|
||||
from fastapi import HTTPException, Request
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge, get_thread_meta_repo
|
||||
from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.routers.threads import _sanitize_log_param
|
||||
from deerflow.runtime import (
|
||||
END_SENTINEL,
|
||||
@@ -256,14 +257,7 @@ async def start_run(
|
||||
"""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
store = get_store(request)
|
||||
event_store = get_run_event_store(request)
|
||||
|
||||
# Get run_events config for journal
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
run_events_config = getattr(get_app_config(), "run_events", None)
|
||||
run_ctx = get_run_context(request)
|
||||
|
||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||
|
||||
@@ -278,6 +272,10 @@ async def start_run(
|
||||
except Exception:
|
||||
pass # Don't block run creation
|
||||
|
||||
# Enrich base context with per-run field
|
||||
if follow_up_to_run_id:
|
||||
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
|
||||
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
@@ -295,23 +293,21 @@ async def start_run(
|
||||
|
||||
# Ensure the thread is visible in /threads/search, even for threads that
|
||||
# were never explicitly created via POST /threads (e.g. stateless runs).
|
||||
store = get_store(request)
|
||||
if store is not None:
|
||||
await _upsert_thread_in_store(store, thread_id, body.metadata)
|
||||
if run_ctx.store is not None:
|
||||
await _upsert_thread_in_store(run_ctx.store, thread_id, body.metadata)
|
||||
|
||||
# Upsert thread metadata in the SQL-backed threads_meta table
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
if thread_meta_repo is not None:
|
||||
if run_ctx.thread_meta_repo is not None:
|
||||
try:
|
||||
existing = await thread_meta_repo.get(thread_id)
|
||||
existing = await run_ctx.thread_meta_repo.get(thread_id)
|
||||
if existing is None:
|
||||
await thread_meta_repo.create(
|
||||
await run_ctx.thread_meta_repo.create(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
else:
|
||||
await thread_meta_repo.update_status(thread_id, "running")
|
||||
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", _sanitize_log_param(thread_id))
|
||||
|
||||
@@ -346,8 +342,7 @@ async def start_run(
|
||||
bridge,
|
||||
run_mgr,
|
||||
record,
|
||||
checkpointer=checkpointer,
|
||||
store=store,
|
||||
ctx=run_ctx,
|
||||
agent_factory=agent_factory,
|
||||
graph_input=graph_input,
|
||||
config=config,
|
||||
@@ -355,10 +350,6 @@ async def start_run(
|
||||
stream_subgraphs=body.stream_subgraphs,
|
||||
interrupt_before=body.interrupt_before,
|
||||
interrupt_after=body.interrupt_after,
|
||||
event_store=event_store,
|
||||
run_events_config=run_events_config,
|
||||
follow_up_to_run_id=follow_up_to_run_id,
|
||||
thread_meta_repo=thread_meta_repo,
|
||||
)
|
||||
)
|
||||
record.task = task
|
||||
@@ -366,8 +357,8 @@ async def start_run(
|
||||
# After the run completes, sync the title generated by TitleMiddleware from
|
||||
# the checkpointer into the Store record so that /threads/search returns the
|
||||
# correct title instead of an empty values dict.
|
||||
if store is not None:
|
||||
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
|
||||
if run_ctx.store is not None:
|
||||
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, run_ctx.checkpointer, run_ctx.store))
|
||||
|
||||
return record
|
||||
|
||||
|
||||
Reference in New Issue
Block a user