mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
refactor(runtime): introduce RunContext to reduce run_agent parameter bloat
Extract checkpointer, store, event_store, run_events_config, thread_meta_repo, and follow_up_to_run_id into a frozen RunContext dataclass. Add get_run_context() in deps.py to build the base context from app.state singletons. start_run() uses dataclasses.replace() to enrich per-run fields before passing ctx to run_agent. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -14,7 +14,7 @@ from contextlib import AsyncExitStack, asynccontextmanager
|
|||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
|
||||||
from deerflow.runtime import RunManager
|
from deerflow.runtime import RunContext, RunManager
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -109,6 +109,25 @@ def get_thread_meta_repo(request: Request):
|
|||||||
return getattr(request.app.state, "thread_meta_repo", None)
|
return getattr(request.app.state, "thread_meta_repo", None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_run_context(request: Request) -> RunContext:
|
||||||
|
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
||||||
|
|
||||||
|
Returns a *base* context with infrastructure dependencies. Callers that
|
||||||
|
need per-run fields (e.g. ``follow_up_to_run_id``) should use
|
||||||
|
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
|
||||||
|
to :func:`run_agent`.
|
||||||
|
"""
|
||||||
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
|
return RunContext(
|
||||||
|
checkpointer=get_checkpointer(request),
|
||||||
|
store=get_store(request),
|
||||||
|
event_store=get_run_event_store(request),
|
||||||
|
run_events_config=getattr(get_app_config(), "run_events", None),
|
||||||
|
thread_meta_repo=get_thread_meta_repo(request),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> str | None:
|
async def get_current_user(request: Request) -> str | None:
|
||||||
"""Extract user identity from request.
|
"""Extract user identity from request.
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ frames, and consuming stream bridge events. Router modules
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
@@ -17,7 +18,7 @@ from typing import Any
|
|||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge, get_thread_meta_repo
|
from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
|
||||||
from app.gateway.routers.threads import _sanitize_log_param
|
from app.gateway.routers.threads import _sanitize_log_param
|
||||||
from deerflow.runtime import (
|
from deerflow.runtime import (
|
||||||
END_SENTINEL,
|
END_SENTINEL,
|
||||||
@@ -256,14 +257,7 @@ async def start_run(
|
|||||||
"""
|
"""
|
||||||
bridge = get_stream_bridge(request)
|
bridge = get_stream_bridge(request)
|
||||||
run_mgr = get_run_manager(request)
|
run_mgr = get_run_manager(request)
|
||||||
checkpointer = get_checkpointer(request)
|
run_ctx = get_run_context(request)
|
||||||
store = get_store(request)
|
|
||||||
event_store = get_run_event_store(request)
|
|
||||||
|
|
||||||
# Get run_events config for journal
|
|
||||||
from deerflow.config import get_app_config
|
|
||||||
|
|
||||||
run_events_config = getattr(get_app_config(), "run_events", None)
|
|
||||||
|
|
||||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||||
|
|
||||||
@@ -278,6 +272,10 @@ async def start_run(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass # Don't block run creation
|
pass # Don't block run creation
|
||||||
|
|
||||||
|
# Enrich base context with per-run field
|
||||||
|
if follow_up_to_run_id:
|
||||||
|
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
record = await run_mgr.create_or_reject(
|
record = await run_mgr.create_or_reject(
|
||||||
thread_id,
|
thread_id,
|
||||||
@@ -295,23 +293,21 @@ async def start_run(
|
|||||||
|
|
||||||
# Ensure the thread is visible in /threads/search, even for threads that
|
# Ensure the thread is visible in /threads/search, even for threads that
|
||||||
# were never explicitly created via POST /threads (e.g. stateless runs).
|
# were never explicitly created via POST /threads (e.g. stateless runs).
|
||||||
store = get_store(request)
|
if run_ctx.store is not None:
|
||||||
if store is not None:
|
await _upsert_thread_in_store(run_ctx.store, thread_id, body.metadata)
|
||||||
await _upsert_thread_in_store(store, thread_id, body.metadata)
|
|
||||||
|
|
||||||
# Upsert thread metadata in the SQL-backed threads_meta table
|
# Upsert thread metadata in the SQL-backed threads_meta table
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
if run_ctx.thread_meta_repo is not None:
|
||||||
if thread_meta_repo is not None:
|
|
||||||
try:
|
try:
|
||||||
existing = await thread_meta_repo.get(thread_id)
|
existing = await run_ctx.thread_meta_repo.get(thread_id)
|
||||||
if existing is None:
|
if existing is None:
|
||||||
await thread_meta_repo.create(
|
await run_ctx.thread_meta_repo.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=body.assistant_id,
|
assistant_id=body.assistant_id,
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await thread_meta_repo.update_status(thread_id, "running")
|
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", _sanitize_log_param(thread_id))
|
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", _sanitize_log_param(thread_id))
|
||||||
|
|
||||||
@@ -346,8 +342,7 @@ async def start_run(
|
|||||||
bridge,
|
bridge,
|
||||||
run_mgr,
|
run_mgr,
|
||||||
record,
|
record,
|
||||||
checkpointer=checkpointer,
|
ctx=run_ctx,
|
||||||
store=store,
|
|
||||||
agent_factory=agent_factory,
|
agent_factory=agent_factory,
|
||||||
graph_input=graph_input,
|
graph_input=graph_input,
|
||||||
config=config,
|
config=config,
|
||||||
@@ -355,10 +350,6 @@ async def start_run(
|
|||||||
stream_subgraphs=body.stream_subgraphs,
|
stream_subgraphs=body.stream_subgraphs,
|
||||||
interrupt_before=body.interrupt_before,
|
interrupt_before=body.interrupt_before,
|
||||||
interrupt_after=body.interrupt_after,
|
interrupt_after=body.interrupt_after,
|
||||||
event_store=event_store,
|
|
||||||
run_events_config=run_events_config,
|
|
||||||
follow_up_to_run_id=follow_up_to_run_id,
|
|
||||||
thread_meta_repo=thread_meta_repo,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
record.task = task
|
record.task = task
|
||||||
@@ -366,8 +357,8 @@ async def start_run(
|
|||||||
# After the run completes, sync the title generated by TitleMiddleware from
|
# After the run completes, sync the title generated by TitleMiddleware from
|
||||||
# the checkpointer into the Store record so that /threads/search returns the
|
# the checkpointer into the Store record so that /threads/search returns the
|
||||||
# correct title instead of an empty values dict.
|
# correct title instead of an empty values dict.
|
||||||
if store is not None:
|
if run_ctx.store is not None:
|
||||||
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
|
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, run_ctx.checkpointer, run_ctx.store))
|
||||||
|
|
||||||
return record
|
return record
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and
|
|||||||
directly from ``deerflow.runtime``.
|
directly from ``deerflow.runtime``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .runs import ConflictError, DisconnectMode, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
|
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
|
||||||
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
|
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
|
||||||
from .store import get_store, make_store, reset_store, store_context
|
from .store import get_store, make_store, reset_store, store_context
|
||||||
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
|
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
|
||||||
@@ -14,6 +14,7 @@ __all__ = [
|
|||||||
# runs
|
# runs
|
||||||
"ConflictError",
|
"ConflictError",
|
||||||
"DisconnectMode",
|
"DisconnectMode",
|
||||||
|
"RunContext",
|
||||||
"RunManager",
|
"RunManager",
|
||||||
"RunRecord",
|
"RunRecord",
|
||||||
"RunStatus",
|
"RunStatus",
|
||||||
|
|||||||
@@ -2,11 +2,12 @@
|
|||||||
|
|
||||||
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
|
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
|
||||||
from .schemas import DisconnectMode, RunStatus
|
from .schemas import DisconnectMode, RunStatus
|
||||||
from .worker import run_agent
|
from .worker import RunContext, run_agent
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ConflictError",
|
"ConflictError",
|
||||||
"DisconnectMode",
|
"DisconnectMode",
|
||||||
|
"RunContext",
|
||||||
"RunManager",
|
"RunManager",
|
||||||
"RunRecord",
|
"RunRecord",
|
||||||
"RunStatus",
|
"RunStatus",
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -34,13 +35,29 @@ logger = logging.getLogger(__name__)
|
|||||||
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
|
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RunContext:
|
||||||
|
"""Infrastructure dependencies for a single agent run.
|
||||||
|
|
||||||
|
Groups checkpointer, store, and persistence-related singletons so that
|
||||||
|
``run_agent`` (and any future callers) receive one object instead of a
|
||||||
|
growing list of keyword arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
checkpointer: Any
|
||||||
|
store: Any | None = field(default=None)
|
||||||
|
event_store: Any | None = field(default=None)
|
||||||
|
run_events_config: Any | None = field(default=None)
|
||||||
|
thread_meta_repo: Any | None = field(default=None)
|
||||||
|
follow_up_to_run_id: str | None = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
async def run_agent(
|
async def run_agent(
|
||||||
bridge: StreamBridge,
|
bridge: StreamBridge,
|
||||||
run_manager: RunManager,
|
run_manager: RunManager,
|
||||||
record: RunRecord,
|
record: RunRecord,
|
||||||
*,
|
*,
|
||||||
checkpointer: Any,
|
ctx: RunContext,
|
||||||
store: Any | None = None,
|
|
||||||
agent_factory: Any,
|
agent_factory: Any,
|
||||||
graph_input: dict,
|
graph_input: dict,
|
||||||
config: dict,
|
config: dict,
|
||||||
@@ -48,13 +65,17 @@ async def run_agent(
|
|||||||
stream_subgraphs: bool = False,
|
stream_subgraphs: bool = False,
|
||||||
interrupt_before: list[str] | Literal["*"] | None = None,
|
interrupt_before: list[str] | Literal["*"] | None = None,
|
||||||
interrupt_after: list[str] | Literal["*"] | None = None,
|
interrupt_after: list[str] | Literal["*"] | None = None,
|
||||||
event_store: Any | None = None,
|
|
||||||
run_events_config: Any | None = None,
|
|
||||||
follow_up_to_run_id: str | None = None,
|
|
||||||
thread_meta_repo: Any | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute an agent in the background, publishing events to *bridge*."""
|
"""Execute an agent in the background, publishing events to *bridge*."""
|
||||||
|
|
||||||
|
# Unpack infrastructure dependencies from RunContext.
|
||||||
|
checkpointer = ctx.checkpointer
|
||||||
|
store = ctx.store
|
||||||
|
event_store = ctx.event_store
|
||||||
|
run_events_config = ctx.run_events_config
|
||||||
|
thread_meta_repo = ctx.thread_meta_repo
|
||||||
|
follow_up_to_run_id = ctx.follow_up_to_run_id
|
||||||
|
|
||||||
run_id = record.run_id
|
run_id = record.run_id
|
||||||
thread_id = record.thread_id
|
thread_id = record.thread_id
|
||||||
requested_modes: set[str] = set(stream_modes or ["values"])
|
requested_modes: set[str] = set(stream_modes or ["values"])
|
||||||
|
|||||||
Reference in New Issue
Block a user