mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
fix(agents): propagate agent_name into ToolRuntime.context for setup_agent (#2679)
* fix(agents): propagate agent_name into ToolRuntime.context for setup_agent (#2677) When creating a custom agent via the web UI, SOUL.md was always written to the global base_dir/SOUL.md instead of agents/<name>/SOUL.md. Root cause: the bootstrap flow sends agent_name via body.context, but two layers were broken: 1. services.py only forwarded body.context keys into config["configurable"]; config["context"] was never populated. 2. worker.py constructed the parent Runtime with a hard-coded {thread_id, run_id} context, ignoring config["context"] entirely. After the langgraph >= 1.1.9 bump (#98a5b34f), ToolRuntime.context no longer falls back to configurable, so setup_agent's runtime.context.get("agent_name") returned None and the tool's silent agent_name=None -> base_dir fallback kicked in, overwriting the global SOUL.md. Fix: - services.py: extract merge_run_context_overrides() and write the whitelisted context keys into both configurable (legacy readers) and context (langgraph 1.1+ ToolRuntime consumers). - worker.py: extract _build_runtime_context() and merge config["context"] into the Runtime's context (without letting callers override thread_id/run_id). The base_dir fallback in setup_agent_tool.py is left in place because the IM /bootstrap channel command depends on it. That code path can be tightened in a follow-up. Adds regression tests covering both helpers. * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -98,6 +98,44 @@ def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
|
|||||||
_DEFAULT_ASSISTANT_ID = "lead_agent"
|
_DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||||
|
|
||||||
|
|
||||||
|
# Whitelist of run-context keys that the langgraph-compat layer forwards from
|
||||||
|
# ``body.context`` into the run config. ``config["context"]`` exists in
|
||||||
|
# LangGraph >=0.6, but these values must be written to both ``configurable``
|
||||||
|
# (for legacy ``_get_runtime_config`` consumers) and ``context`` because
|
||||||
|
# LangGraph >=1.1.9 no longer makes ``ToolRuntime.context`` fall back to
|
||||||
|
# ``configurable`` for consumers like ``setup_agent``.
|
||||||
|
_CONTEXT_CONFIGURABLE_KEYS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"model_name",
|
||||||
|
"mode",
|
||||||
|
"thinking_enabled",
|
||||||
|
"reasoning_effort",
|
||||||
|
"is_plan_mode",
|
||||||
|
"subagent_enabled",
|
||||||
|
"max_concurrent_subagents",
|
||||||
|
"agent_name",
|
||||||
|
"is_bootstrap",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, Any] | None) -> None:
|
||||||
|
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
|
||||||
|
and ``config['context']`` so they are visible to legacy configurable readers and
|
||||||
|
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool —
|
||||||
|
see issue #2677)."""
|
||||||
|
if not context:
|
||||||
|
return
|
||||||
|
configurable = config.setdefault("configurable", {})
|
||||||
|
runtime_context = config.setdefault("context", {})
|
||||||
|
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
||||||
|
if key in context:
|
||||||
|
if isinstance(configurable, dict):
|
||||||
|
configurable.setdefault(key, context[key])
|
||||||
|
if isinstance(runtime_context, dict):
|
||||||
|
runtime_context.setdefault(key, context[key])
|
||||||
|
|
||||||
|
|
||||||
def resolve_agent_factory(assistant_id: str | None):
|
def resolve_agent_factory(assistant_id: str | None):
|
||||||
"""Resolve the agent factory callable from config.
|
"""Resolve the agent factory callable from config.
|
||||||
|
|
||||||
@@ -245,27 +283,11 @@ async def start_run(
|
|||||||
graph_input = normalize_input(body.input)
|
graph_input = normalize_input(body.input)
|
||||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||||
|
|
||||||
# Merge DeerFlow-specific context overrides into configurable.
|
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
||||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||||
context = getattr(body, "context", None)
|
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||||
if context:
|
|
||||||
_CONTEXT_CONFIGURABLE_KEYS = {
|
|
||||||
"model_name",
|
|
||||||
"mode",
|
|
||||||
"thinking_enabled",
|
|
||||||
"reasoning_effort",
|
|
||||||
"is_plan_mode",
|
|
||||||
"subagent_enabled",
|
|
||||||
"max_concurrent_subagents",
|
|
||||||
"agent_name",
|
|
||||||
"is_bootstrap",
|
|
||||||
}
|
|
||||||
configurable = config.setdefault("configurable", {})
|
|
||||||
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
|
||||||
if key in context:
|
|
||||||
configurable.setdefault(key, context[key])
|
|
||||||
|
|
||||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||||
|
|
||||||
|
|||||||
@@ -39,6 +39,24 @@ logger = logging.getLogger(__name__)
|
|||||||
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
|
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_runtime_context(thread_id: str, run_id: str, caller_context: Any | None) -> dict[str, Any]:
|
||||||
|
"""Build the dict that becomes ``ToolRuntime.context`` for the run.
|
||||||
|
|
||||||
|
Always includes ``thread_id`` and ``run_id``. Additional keys from the caller's
|
||||||
|
``config['context']`` (e.g. ``agent_name`` for the bootstrap flow — issue #2677)
|
||||||
|
are merged in but never override ``thread_id``/``run_id``.
|
||||||
|
|
||||||
|
langgraph 1.1+ surfaces this as ``runtime.context`` via the parent runtime stored
|
||||||
|
under ``config['configurable']['__pregel_runtime']`` — see
|
||||||
|
``langgraph.pregel.main`` where ``parent_runtime.merge(...)`` is invoked.
|
||||||
|
"""
|
||||||
|
runtime_ctx: dict[str, Any] = {"thread_id": thread_id, "run_id": run_id}
|
||||||
|
if isinstance(caller_context, dict):
|
||||||
|
for key, value in caller_context.items():
|
||||||
|
runtime_ctx.setdefault(key, value)
|
||||||
|
return runtime_ctx
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class RunContext:
|
class RunContext:
|
||||||
"""Infrastructure dependencies for a single agent run.
|
"""Infrastructure dependencies for a single agent run.
|
||||||
@@ -169,15 +187,15 @@ async def run_agent(
|
|||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
# Inject runtime context so middlewares can access thread_id
|
# Inject runtime context so middlewares and tools (via ToolRuntime.context) can
|
||||||
# (langgraph-cli does this automatically; we must do it manually)
|
# access thread-level data. langgraph-cli does this automatically; we must do it
|
||||||
runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
|
# manually here because we drive the graph through ``agent.astream(config=...)``
|
||||||
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
# without passing the official ``context=`` parameter.
|
||||||
# prefers it over ``configurable`` for thread-level data), make
|
runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"))
|
||||||
# sure ``thread_id`` is available there too.
|
|
||||||
if "context" in config and isinstance(config["context"], dict):
|
if "context" in config and isinstance(config["context"], dict):
|
||||||
config["context"].setdefault("thread_id", thread_id)
|
config["context"].setdefault("thread_id", thread_id)
|
||||||
config["context"].setdefault("run_id", run_id)
|
config["context"].setdefault("run_id", run_id)
|
||||||
|
runtime = Runtime(context=runtime_ctx, store=store)
|
||||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||||
|
|
||||||
# Inject RunJournal as a LangChain callback handler.
|
# Inject RunJournal as a LangChain callback handler.
|
||||||
|
|||||||
@@ -256,6 +256,37 @@ def test_context_merges_into_configurable():
|
|||||||
assert "thread_id" not in {k for k in context if k in _CONTEXT_CONFIGURABLE_KEYS}
|
assert "thread_id" not in {k for k in context if k in _CONTEXT_CONFIGURABLE_KEYS}
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_run_context_overrides_propagates_to_runtime_context():
|
||||||
|
"""Regression for issue #2677: ``agent_name`` (and other whitelisted keys) from
|
||||||
|
``body.context`` must be propagated into BOTH ``config['configurable']`` and
|
||||||
|
``config['context']``. Previously only ``configurable`` was populated, so after
|
||||||
|
the LangGraph 1.1.x upgrade removed the fallback from ``configurable``, the
|
||||||
|
``setup_agent`` tool read ``runtime.context`` with ``agent_name=None`` and
|
||||||
|
silently wrote SOUL.md to the global base_dir.
|
||||||
|
"""
|
||||||
|
from app.gateway.services import build_run_config, merge_run_context_overrides
|
||||||
|
|
||||||
|
config = build_run_config("thread-1", None, None)
|
||||||
|
merge_run_context_overrides(config, {"agent_name": "my-agent", "is_bootstrap": True, "thread_id": "ignored"})
|
||||||
|
|
||||||
|
assert config["configurable"]["agent_name"] == "my-agent"
|
||||||
|
assert config["configurable"]["is_bootstrap"] is True
|
||||||
|
assert config["context"]["agent_name"] == "my-agent"
|
||||||
|
assert config["context"]["is_bootstrap"] is True
|
||||||
|
# Non-whitelisted keys are not forwarded.
|
||||||
|
assert "thread_id" not in config["context"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_run_context_overrides_noop_for_empty_context():
|
||||||
|
from app.gateway.services import build_run_config, merge_run_context_overrides
|
||||||
|
|
||||||
|
config = build_run_config("thread-1", None, None)
|
||||||
|
before = {k: dict(v) if isinstance(v, dict) else v for k, v in config.items()}
|
||||||
|
merge_run_context_overrides(config, None)
|
||||||
|
merge_run_context_overrides(config, {})
|
||||||
|
assert config == before
|
||||||
|
|
||||||
|
|
||||||
def test_context_does_not_override_existing_configurable():
|
def test_context_does_not_override_existing_configurable():
|
||||||
"""Values already in config.configurable must NOT be overridden by context."""
|
"""Values already in config.configurable must NOT be overridden by context."""
|
||||||
from app.gateway.services import build_run_config
|
from app.gateway.services import build_run_config
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from unittest.mock import AsyncMock, call
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from deerflow.runtime.runs.worker import _agent_factory_supports_app_config, _rollback_to_pre_run_checkpoint
|
from deerflow.runtime.runs.worker import _agent_factory_supports_app_config, _build_runtime_context, _rollback_to_pre_run_checkpoint
|
||||||
|
|
||||||
|
|
||||||
class FakeCheckpointer:
|
class FakeCheckpointer:
|
||||||
@@ -221,6 +221,43 @@ def test_agent_factory_supports_app_config_detects_supported_signature():
|
|||||||
assert _agent_factory_supports_app_config(factory) is True
|
assert _agent_factory_supports_app_config(factory) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_runtime_context_defaults_to_thread_and_run_id():
|
||||||
|
ctx = _build_runtime_context("thread-1", "run-1", None)
|
||||||
|
assert ctx == {"thread_id": "thread-1", "run_id": "run-1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_runtime_context_merges_caller_context():
|
||||||
|
"""Regression for issue #2677: keys from ``config['context']`` (e.g. ``agent_name``)
|
||||||
|
must be merged into the Runtime's context so that ``ToolRuntime.context`` — which
|
||||||
|
is what ``setup_agent`` reads — can see them."""
|
||||||
|
caller_context = {"agent_name": "my-agent", "is_bootstrap": True, "model_name": "gpt-4"}
|
||||||
|
|
||||||
|
ctx = _build_runtime_context("thread-1", "run-1", caller_context)
|
||||||
|
|
||||||
|
assert ctx["thread_id"] == "thread-1"
|
||||||
|
assert ctx["run_id"] == "run-1"
|
||||||
|
assert ctx["agent_name"] == "my-agent"
|
||||||
|
assert ctx["is_bootstrap"] is True
|
||||||
|
assert ctx["model_name"] == "gpt-4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_runtime_context_caller_cannot_override_thread_id_or_run_id():
|
||||||
|
"""A malicious or buggy caller must not be able to overwrite the worker-assigned
|
||||||
|
``thread_id`` / ``run_id`` by stuffing them into ``config['context']``."""
|
||||||
|
caller_context = {"thread_id": "spoofed", "run_id": "spoofed", "agent_name": "ok"}
|
||||||
|
|
||||||
|
ctx = _build_runtime_context("real-thread", "real-run", caller_context)
|
||||||
|
|
||||||
|
assert ctx["thread_id"] == "real-thread"
|
||||||
|
assert ctx["run_id"] == "real-run"
|
||||||
|
assert ctx["agent_name"] == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_runtime_context_ignores_non_dict_caller_context():
|
||||||
|
ctx = _build_runtime_context("thread-1", "run-1", "not-a-dict")
|
||||||
|
assert ctx == {"thread_id": "thread-1", "run_id": "run-1"}
|
||||||
|
|
||||||
|
|
||||||
def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_fails(monkeypatch):
|
def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_fails(monkeypatch):
|
||||||
class BrokenCallable:
|
class BrokenCallable:
|
||||||
def __call__(self, **kwargs):
|
def __call__(self, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user