fix: keep new agent bootstrap in user scope (#2784)

This commit is contained in:
Eilen Shin
2026-05-09 19:43:50 +08:00
committed by GitHub
parent 417416087b
commit 1c96a6afc8
7 changed files with 75 additions and 50 deletions
+19
View File
@@ -136,6 +136,24 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
runtime_context.setdefault(key, context[key])
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
"""Stamp the authenticated user into the run context for background tools.
Tool execution may happen after the request handler has returned, so tools
that persist user-scoped files should not rely only on ambient ContextVars.
The value comes from server-side auth state, never from client context.
"""
user = getattr(request.state, "user", None)
user_id = getattr(user, "id", None)
if user_id is None:
return
runtime_context = config.setdefault("context", {})
if isinstance(runtime_context, dict):
runtime_context["user_id"] = str(user_id)
def resolve_agent_factory(assistant_id: str | None):
"""Resolve the agent factory callable from config.
@@ -288,6 +306,7 @@ async def start_run(
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
merge_run_context_overrides(config, getattr(body, "context", None))
inject_authenticated_user_context(config, request)
stream_modes = normalize_stream_modes(body.stream_mode)
@@ -13,6 +13,13 @@ from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
def _get_runtime_user_id(runtime: Runtime) -> str:
context_user_id = runtime.context.get("user_id") if runtime.context else None
if context_user_id:
return str(context_user_id)
return get_effective_user_id()
@tool
def setup_agent(
soul: str,
@@ -38,7 +45,7 @@ def setup_agent(
if agent_name:
# Custom agents are persisted under the current user's bucket so
# different users do not see each other's agents.
user_id = get_effective_user_id()
user_id = _get_runtime_user_id(runtime)
agent_dir = paths.user_agent_dir(user_id, agent_name)
else:
# Default agent (no agent_name): SOUL.md lives at the global base dir.
+15
View File
@@ -324,6 +324,21 @@ def test_context_does_not_override_existing_configurable():
assert config["configurable"]["subagent_enabled"] is True
def test_inject_authenticated_user_context_overrides_client_user_id():
"""Run context should carry the authenticated user, not client-supplied user_id."""
from types import SimpleNamespace
from app.gateway.services import build_run_config, inject_authenticated_user_context
config = build_run_config("thread-1", None, None)
config["context"] = {"user_id": "spoofed-client"}
request = SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(id="auth-user-42")))
inject_authenticated_user_context(config, request)
assert config["context"]["user_id"] == "auth-user-42"
# ---------------------------------------------------------------------------
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
# ---------------------------------------------------------------------------
+22
View File
@@ -6,6 +6,8 @@ from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from deerflow.tools.builtins.setup_agent_tool import setup_agent
# --- Helpers ---
@@ -126,3 +128,23 @@ class TestSetupAgentNoDataLoss:
assert agent_dir.exists()
assert (agent_dir / "SOUL.md").read_text() == "# My Agent"
assert (agent_dir / "config.yaml").exists()
@pytest.mark.no_auto_user
def test_runtime_user_id_used_when_contextvar_missing(self, tmp_path: Path):
"""setup_agent should not fall back to default when runtime carries user_id."""
runtime = _DummyRuntime(
context={"agent_name": "test-agent", "user_id": "auth-user-42"},
tool_call_id="tool-3",
)
with patch("deerflow.tools.builtins.setup_agent_tool.get_paths", return_value=_make_paths_mock(tmp_path)):
setup_agent.func(
soul="# My Agent",
description="A test agent",
runtime=runtime,
)
expected_dir = tmp_path / "users" / "auth-user-42" / "agents" / "test-agent"
default_dir = tmp_path / "users" / "default" / "agents" / "test-agent"
assert (expected_dir / "SOUL.md").read_text() == "# My Agent"
assert not default_dir.exists()