mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-14 03:15:58 +00:00
f43aa78107
* fix(agents): sync agent_name across context/configurable and reject empty soul (#3549) Two independent issues caused custom agent creation to silently fail: 1. build_run_config only wrote agent_name into one container (configurable or context), so setup_agent — which reads ToolRuntime.context exclusively since LangGraph >=1.1.9 — saw agent_name=None and wrote SOUL.md to the global base_dir instead of users/{user_id}/agents/{name}/. Mirror the dual-write pattern already used by merge_run_context_overrides and naming.py so both containers always carry the same value. 2. setup_agent persisted whatever soul string it received, including empty or whitespace-only content, and still reported success. The frontend then surfaced an unusable agent and the global default SOUL.md could be silently overwritten with empty content. Reject empty soul before any filesystem operation so the model can retry. Tests: - test_gateway_services.py: dual-write regressions for both configurable and context entry paths, explicit-agent-name precedence on both sides, and a shape-parity test against merge_run_context_overrides. - test_setup_agent_tool.py: empty/whitespace soul rejection, plus no-overwrite guarantees for existing global and per-agent SOUL.md. * Update services.py
511 lines
21 KiB
Python
511 lines
21 KiB
Python
"""Run lifecycle service layer.
|
|
|
|
Centralizes the business logic for creating runs, formatting SSE
|
|
frames, and consuming stream bridge events. Router modules
|
|
(``thread_runs``, ``runs``) are thin HTTP handlers that delegate here.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import re
|
|
from collections.abc import Mapping
|
|
from types import SimpleNamespace
|
|
from typing import Any
|
|
|
|
from fastapi import HTTPException, Request
|
|
from langchain_core.messages import BaseMessage
|
|
from langchain_core.messages.utils import convert_to_messages
|
|
|
|
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
|
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
|
|
from app.gateway.utils import sanitize_log_param
|
|
from deerflow.config.app_config import get_app_config
|
|
from deerflow.runtime import (
|
|
END_SENTINEL,
|
|
HEARTBEAT_SENTINEL,
|
|
ConflictError,
|
|
DisconnectMode,
|
|
RunManager,
|
|
RunRecord,
|
|
RunStatus,
|
|
StreamBridge,
|
|
UnsupportedStrategyError,
|
|
run_agent,
|
|
)
|
|
from deerflow.runtime.runs.naming import resolve_root_run_name
|
|
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SSE formatting
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def format_sse(event: str, data: Any, *, event_id: str | None = None) -> str:
|
|
"""Format a single SSE frame.
|
|
|
|
Field order: ``event:`` -> ``data:`` -> ``id:`` (optional) -> blank line.
|
|
This matches the LangGraph Platform wire format consumed by the
|
|
``useStream`` React hook and the Python ``langgraph-sdk`` SSE decoder.
|
|
"""
|
|
payload = json.dumps(data, default=str, ensure_ascii=False)
|
|
parts = [f"event: {event}", f"data: {payload}"]
|
|
if event_id:
|
|
parts.append(f"id: {event_id}")
|
|
parts.append("")
|
|
parts.append("")
|
|
return "\n".join(parts)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Input / config helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
|
|
"""Normalize the stream_mode parameter to a list.
|
|
|
|
Default matches what ``useStream`` expects: values + messages-tuple.
|
|
"""
|
|
if raw is None:
|
|
return ["values"]
|
|
if isinstance(raw, str):
|
|
return [raw]
|
|
return raw if raw else ["values"]
|
|
|
|
|
|
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
|
|
"""Convert LangGraph Platform input format to LangChain state dict.
|
|
|
|
Delegates dict→message coercion to ``langchain_core.messages.utils.convert_to_messages``
|
|
so that ``additional_kwargs`` (e.g. uploaded-file metadata — gh #3132), ``id``,
|
|
``name``, and non-human roles (ai/system/tool) survive unchanged. An earlier
|
|
hand-rolled version only forwarded ``content`` and collapsed every role to
|
|
``HumanMessage``, which silently stripped frontend-supplied attachments.
|
|
|
|
Malformed message dicts (missing ``role``/``type``/``content``, unsupported
|
|
role, etc.) raise ``HTTPException(400)`` with the offending index, instead
|
|
of bubbling up as a 500. The gateway is a system boundary, so per-entry
|
|
validation errors are the right shape for clients to retry against.
|
|
"""
|
|
if raw_input is None:
|
|
return {}
|
|
messages = raw_input.get("messages")
|
|
if messages and isinstance(messages, list):
|
|
converted: list[Any] = []
|
|
for index, msg in enumerate(messages):
|
|
if isinstance(msg, BaseMessage):
|
|
converted.append(msg)
|
|
elif isinstance(msg, dict):
|
|
try:
|
|
converted.extend(convert_to_messages([msg]))
|
|
except (ValueError, TypeError, NotImplementedError) as exc:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Invalid message at input.messages[{index}]: {exc}",
|
|
) from exc
|
|
else:
|
|
converted.append(msg)
|
|
return {**raw_input, "messages": converted}
|
|
return raw_input
|
|
|
|
|
|
_DEFAULT_ASSISTANT_ID = "lead_agent"
|
|
|
|
|
|
# Whitelist of run-context keys that the langgraph-compat layer forwards from
|
|
# ``body.context`` into the run config. ``config["context"]`` exists in
|
|
# LangGraph >=0.6, but these values must be written to both ``configurable``
|
|
# (for legacy ``_get_runtime_config`` consumers) and ``context`` because
|
|
# LangGraph >=1.1.9 no longer makes ``ToolRuntime.context`` fall back to
|
|
# ``configurable`` for consumers like ``setup_agent``.
|
|
_CONTEXT_CONFIGURABLE_KEYS: frozenset[str] = frozenset(
|
|
{
|
|
"model_name",
|
|
"mode",
|
|
"thinking_enabled",
|
|
"reasoning_effort",
|
|
"is_plan_mode",
|
|
"subagent_enabled",
|
|
"max_concurrent_subagents",
|
|
"agent_name",
|
|
"is_bootstrap",
|
|
}
|
|
)
|
|
|
|
|
|
def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, Any] | None) -> None:
|
|
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
|
|
and ``config['context']`` so they are visible to legacy configurable readers and
|
|
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool —
|
|
see issue #2677).
|
|
|
|
``user_id`` is intentionally propagated into ``config['context']`` in addition to
|
|
the whitelisted keys, so non-web callers (e.g. IM channels) that supply identity in
|
|
``body.context`` keep it on ``ToolRuntime.context``. It is merged with
|
|
``setdefault`` so a server-authenticated id stamped by
|
|
:func:`inject_authenticated_user_context` always wins over the client-supplied one.
|
|
"""
|
|
if not context:
|
|
return
|
|
configurable = config.setdefault("configurable", {})
|
|
runtime_context = config.setdefault("context", {})
|
|
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
|
if key in context:
|
|
if isinstance(configurable, dict):
|
|
configurable.setdefault(key, context[key])
|
|
if isinstance(runtime_context, dict):
|
|
runtime_context.setdefault(key, context[key])
|
|
if "user_id" in context and isinstance(runtime_context, dict):
|
|
runtime_context.setdefault("user_id", context["user_id"])
|
|
|
|
|
|
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
|
|
"""Stamp the authenticated user into the run context for background tools.
|
|
|
|
Tool execution may happen after the request handler has returned, so tools
|
|
that persist user-scoped files should not rely only on ambient ContextVars.
|
|
The value comes from server-side auth state, never from client context.
|
|
"""
|
|
|
|
user = getattr(request.state, "user", None)
|
|
user_id = getattr(user, "id", None)
|
|
if user_id is None:
|
|
return
|
|
|
|
if getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
|
return
|
|
|
|
runtime_context = config.setdefault("context", {})
|
|
if isinstance(runtime_context, dict):
|
|
runtime_context["user_id"] = str(user_id)
|
|
|
|
|
|
def resolve_agent_factory(assistant_id: str | None):
|
|
"""Resolve the agent factory callable from config.
|
|
|
|
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
|
|
injected into ``configurable`` or ``context`` — see
|
|
:func:`build_run_config`. All ``assistant_id`` values therefore map to the
|
|
same factory; the routing happens inside ``make_lead_agent`` when it reads
|
|
``cfg["agent_name"]``.
|
|
"""
|
|
from deerflow.agents.lead_agent.agent import make_lead_agent
|
|
|
|
return make_lead_agent
|
|
|
|
|
|
def build_run_config(
|
|
thread_id: str,
|
|
request_config: dict[str, Any] | None,
|
|
metadata: dict[str, Any] | None,
|
|
*,
|
|
assistant_id: str | None = None,
|
|
) -> dict[str, Any]:
|
|
"""Build a RunnableConfig dict for the agent.
|
|
|
|
When *assistant_id* refers to a custom agent (anything other than
|
|
``"lead_agent"`` / ``None``), the name is forwarded as ``agent_name`` in
|
|
both ``configurable`` and ``context`` so it is visible to legacy
|
|
configurable readers and to LangGraph ``ToolRuntime.context`` consumers
|
|
(e.g. the ``setup_agent`` tool, which since LangGraph >=1.1.9 no longer
|
|
falls back from ``context`` to ``configurable``). An explicit
|
|
``agent_name`` in either container takes precedence over the value
|
|
derived from ``assistant_id``. ``make_lead_agent`` reads this key to
|
|
load the matching ``agents/<name>/SOUL.md`` and per-agent config —
|
|
without it the agent silently runs as the default lead agent.
|
|
|
|
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
|
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
|
identically.
|
|
"""
|
|
config: dict[str, Any] = {"recursion_limit": 100}
|
|
if request_config:
|
|
# LangGraph >= 0.6.0 introduced ``context`` as the preferred way to
|
|
# pass thread-level data and rejects requests that include both
|
|
# ``configurable`` and ``context``. If the caller already sends
|
|
# ``context``, honour it and skip our own ``configurable`` dict.
|
|
if "context" in request_config:
|
|
if "configurable" in request_config:
|
|
logger.warning(
|
|
"build_run_config: client sent both 'context' and 'configurable'; preferring 'context' (LangGraph >= 0.6.0). thread_id=%s, caller_configurable keys=%s",
|
|
thread_id,
|
|
list(request_config.get("configurable", {}).keys()),
|
|
)
|
|
context_value = request_config["context"]
|
|
if context_value is None:
|
|
context = {}
|
|
elif isinstance(context_value, Mapping):
|
|
context = dict(context_value)
|
|
else:
|
|
raise ValueError("request config 'context' must be a mapping or null.")
|
|
config["context"] = context
|
|
else:
|
|
configurable = {"thread_id": thread_id}
|
|
configurable.update(request_config.get("configurable", {}))
|
|
config["configurable"] = configurable
|
|
for k, v in request_config.items():
|
|
if k not in ("configurable", "context"):
|
|
config[k] = v
|
|
else:
|
|
config["configurable"] = {"thread_id": thread_id}
|
|
|
|
# Inject custom agent name when the caller specified a non-default assistant.
|
|
# Honour an explicit agent_name in either runtime options container.
|
|
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID:
|
|
normalized = assistant_id.strip().lower().replace("_", "-")
|
|
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
|
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
|
configurable = config.setdefault("configurable", {})
|
|
runtime_context = config.setdefault("context", {})
|
|
explicit_agent_name: str | None = None
|
|
if isinstance(configurable, dict) and isinstance(configurable.get("agent_name"), str):
|
|
explicit_agent_name = configurable["agent_name"]
|
|
elif isinstance(runtime_context, dict) and isinstance(runtime_context.get("agent_name"), str):
|
|
explicit_agent_name = runtime_context["agent_name"]
|
|
effective_agent_name = explicit_agent_name or normalized
|
|
if isinstance(configurable, dict):
|
|
configurable["agent_name"] = effective_agent_name
|
|
if isinstance(runtime_context, dict):
|
|
runtime_context["agent_name"] = effective_agent_name
|
|
config.setdefault("run_name", resolve_root_run_name(config, normalized))
|
|
if metadata:
|
|
config.setdefault("metadata", {}).update(metadata)
|
|
return config
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Run lifecycle
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def start_run(
|
|
body: Any,
|
|
thread_id: str,
|
|
request: Request,
|
|
) -> RunRecord:
|
|
"""Create a RunRecord and launch the background agent task.
|
|
|
|
Parameters
|
|
----------
|
|
body : RunCreateRequest
|
|
The validated request body (typed as Any to avoid circular import
|
|
with the router module that defines the Pydantic model).
|
|
thread_id : str
|
|
Target thread.
|
|
request : Request
|
|
FastAPI request — used to retrieve singletons from ``app.state``.
|
|
"""
|
|
bridge = get_stream_bridge(request)
|
|
run_mgr = get_run_manager(request)
|
|
run_ctx = get_run_context(request)
|
|
|
|
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
|
|
|
body_context = getattr(body, "context", None) or {}
|
|
model_name = body_context.get("model_name")
|
|
|
|
# Coerce non-string model_name values to str before truncation.
|
|
if model_name is not None and not isinstance(model_name, str):
|
|
model_name = str(model_name)
|
|
|
|
# Validate model against the allowlist when a model_name is provided.
|
|
if model_name:
|
|
app_config = get_app_config()
|
|
resolved = app_config.get_model_config(model_name)
|
|
if resolved is None:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
|
)
|
|
|
|
owner_user_id = get_trusted_internal_owner_user_id(request)
|
|
# Stateless run endpoints carry thread_id in the request *body*, so the
|
|
# @require_permission(owner_check=True) decorator -- which resolves ownership
|
|
# from the path param -- cannot protect them. Enforce thread ownership here,
|
|
# before any run is created, so one user cannot start runs on (or read /wait
|
|
# checkpoint state from) another user's thread. Missing rows (auto-created
|
|
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
|
|
# via check_access; only a thread already owned by another user is rejected
|
|
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
|
|
# channel runs act on behalf of the connection owner carried in
|
|
# X-DeerFlow-Owner-User-Id, so they are scoped to that owner instead of
|
|
# bypassing the check -- a leaked internal token must not grant cross-user
|
|
# thread access.
|
|
user = getattr(request.state, "user", None)
|
|
if user is not None:
|
|
allowed = await run_ctx.thread_store.check_access(thread_id, str(user.id))
|
|
if not allowed and owner_user_id and getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
|
# Channel workers may also act for the connection owner named in
|
|
# the trusted header (e.g. claiming a legacy default-owned channel
|
|
# thread for its real owner).
|
|
allowed = await run_ctx.thread_store.check_access(thread_id, owner_user_id)
|
|
if not allowed:
|
|
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
|
|
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
|
|
try:
|
|
try:
|
|
record = await run_mgr.create_or_reject(
|
|
thread_id,
|
|
body.assistant_id,
|
|
on_disconnect=disconnect,
|
|
metadata=body.metadata or {},
|
|
kwargs={"input": body.input, "config": body.config},
|
|
multitask_strategy=body.multitask_strategy,
|
|
model_name=model_name,
|
|
user_id=owner_user_id,
|
|
)
|
|
except ConflictError as exc:
|
|
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
|
except UnsupportedStrategyError as exc:
|
|
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
|
|
|
# Upsert thread metadata so the thread appears in /threads/search,
|
|
# even for threads that were never explicitly created via POST /threads
|
|
# (e.g. stateless runs).
|
|
try:
|
|
existing = await run_ctx.thread_store.get(thread_id)
|
|
if existing is None and owner_user_id:
|
|
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
|
|
if unscoped_existing is not None:
|
|
if unscoped_existing.get("user_id") != owner_user_id:
|
|
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
|
|
existing = await run_ctx.thread_store.get(thread_id)
|
|
if existing is None:
|
|
await run_ctx.thread_store.create(
|
|
thread_id,
|
|
assistant_id=body.assistant_id,
|
|
metadata=body.metadata,
|
|
)
|
|
else:
|
|
await run_ctx.thread_store.update_status(thread_id, "running")
|
|
except Exception:
|
|
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
|
|
|
agent_factory = resolve_agent_factory(body.assistant_id)
|
|
graph_input = normalize_input(body.input)
|
|
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
|
|
|
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
|
# The ``context`` field is a custom extension for the langgraph-compat layer
|
|
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
|
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
|
merge_run_context_overrides(config, getattr(body, "context", None))
|
|
inject_authenticated_user_context(config, request)
|
|
|
|
stream_modes = normalize_stream_modes(body.stream_mode)
|
|
|
|
task = asyncio.create_task(
|
|
run_agent(
|
|
bridge,
|
|
run_mgr,
|
|
record,
|
|
ctx=run_ctx,
|
|
agent_factory=agent_factory,
|
|
graph_input=graph_input,
|
|
config=config,
|
|
stream_modes=stream_modes,
|
|
stream_subgraphs=body.stream_subgraphs,
|
|
interrupt_before=body.interrupt_before,
|
|
interrupt_after=body.interrupt_after,
|
|
)
|
|
)
|
|
record.task = task
|
|
|
|
# Title sync is handled by worker.py's finally block which reads the
|
|
# title from the checkpoint and calls thread_store.update_display_name
|
|
# after the run completes.
|
|
|
|
return record
|
|
finally:
|
|
if owner_context_token is not None:
|
|
reset_current_user(owner_context_token)
|
|
|
|
|
|
async def sse_consumer(
|
|
bridge: StreamBridge,
|
|
record: RunRecord,
|
|
request: Request,
|
|
run_mgr: RunManager,
|
|
):
|
|
"""Async generator that yields SSE frames from the bridge.
|
|
|
|
The ``finally`` block implements ``on_disconnect`` semantics:
|
|
- ``cancel``: abort the background task on client disconnect.
|
|
- ``continue``: let the task run; events are discarded.
|
|
"""
|
|
last_event_id = request.headers.get("Last-Event-ID")
|
|
try:
|
|
async for entry in bridge.subscribe(record.run_id, last_event_id=last_event_id):
|
|
if await request.is_disconnected():
|
|
break
|
|
|
|
if entry is HEARTBEAT_SENTINEL:
|
|
yield ": heartbeat\n\n"
|
|
continue
|
|
|
|
if entry is END_SENTINEL:
|
|
yield format_sse("end", None, event_id=entry.id or None)
|
|
return
|
|
|
|
yield format_sse(entry.event, entry.data, event_id=entry.id or None)
|
|
|
|
finally:
|
|
if record.status in (RunStatus.pending, RunStatus.running):
|
|
if record.on_disconnect == DisconnectMode.cancel:
|
|
await run_mgr.cancel(record.run_id)
|
|
|
|
|
|
async def wait_for_run_completion(
|
|
bridge: StreamBridge,
|
|
record: RunRecord,
|
|
request: Request,
|
|
run_mgr: RunManager,
|
|
) -> bool:
|
|
"""Block until the run publishes ``END_SENTINEL``, honouring on_disconnect.
|
|
|
|
The non-streaming ``/wait`` endpoints used to ``await record.task``
|
|
directly with no disconnect handling. When the client (or an
|
|
intermediate HTTP proxy) timed out during a long tool call such as
|
|
``pip install``, the handler would swallow ``CancelledError`` and
|
|
serialize whatever checkpoint happened to exist — masking a half-finished
|
|
run as a normal completion (issue #3265).
|
|
|
|
This helper consumes the same bridge that ``sse_consumer`` does so the
|
|
wait path shares its disconnect semantics: each wake-up polls
|
|
``request.is_disconnected()``; on a real disconnect it cancels the
|
|
background run when ``record.on_disconnect`` is ``cancel``. The bridge's
|
|
heartbeat sentinels guarantee at least one wake-up per
|
|
``heartbeat_interval`` even when the agent emits no events for a while.
|
|
|
|
Returns:
|
|
``True`` when ``END_SENTINEL`` was observed (run reached a terminal
|
|
state), ``False`` when the loop exited because the client
|
|
disconnected. Callers must skip checkpoint serialization on
|
|
``False`` so a partial checkpoint is not returned as a normal
|
|
response.
|
|
"""
|
|
completed = False
|
|
try:
|
|
async for entry in bridge.subscribe(record.run_id):
|
|
# END_SENTINEL means the run reached a terminal state; honour it
|
|
# even if the client just disconnected so the caller still serializes
|
|
# the real final checkpoint.
|
|
if entry is END_SENTINEL:
|
|
completed = True
|
|
return True
|
|
if await request.is_disconnected():
|
|
break
|
|
# Heartbeats and regular events: keep waiting for END_SENTINEL.
|
|
return completed
|
|
finally:
|
|
if not completed and record.status in (RunStatus.pending, RunStatus.running):
|
|
if record.on_disconnect == DisconnectMode.cancel:
|
|
await run_mgr.cancel(record.run_id)
|