mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
fix: keep new agent bootstrap in user scope (#2784)
This commit is contained in:
@@ -136,6 +136,24 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
|
|||||||
runtime_context.setdefault(key, context[key])
|
runtime_context.setdefault(key, context[key])
|
||||||
|
|
||||||
|
|
||||||
|
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
|
||||||
|
"""Stamp the authenticated user into the run context for background tools.
|
||||||
|
|
||||||
|
Tool execution may happen after the request handler has returned, so tools
|
||||||
|
that persist user-scoped files should not rely only on ambient ContextVars.
|
||||||
|
The value comes from server-side auth state, never from client context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user = getattr(request.state, "user", None)
|
||||||
|
user_id = getattr(user, "id", None)
|
||||||
|
if user_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
runtime_context = config.setdefault("context", {})
|
||||||
|
if isinstance(runtime_context, dict):
|
||||||
|
runtime_context["user_id"] = str(user_id)
|
||||||
|
|
||||||
|
|
||||||
def resolve_agent_factory(assistant_id: str | None):
|
def resolve_agent_factory(assistant_id: str | None):
|
||||||
"""Resolve the agent factory callable from config.
|
"""Resolve the agent factory callable from config.
|
||||||
|
|
||||||
@@ -288,6 +306,7 @@ async def start_run(
|
|||||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||||
|
inject_authenticated_user_context(config, request)
|
||||||
|
|
||||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,13 @@ from deerflow.tools.types import Runtime
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_runtime_user_id(runtime: Runtime) -> str:
|
||||||
|
context_user_id = runtime.context.get("user_id") if runtime.context else None
|
||||||
|
if context_user_id:
|
||||||
|
return str(context_user_id)
|
||||||
|
return get_effective_user_id()
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def setup_agent(
|
def setup_agent(
|
||||||
soul: str,
|
soul: str,
|
||||||
@@ -38,7 +45,7 @@ def setup_agent(
|
|||||||
if agent_name:
|
if agent_name:
|
||||||
# Custom agents are persisted under the current user's bucket so
|
# Custom agents are persisted under the current user's bucket so
|
||||||
# different users do not see each other's agents.
|
# different users do not see each other's agents.
|
||||||
user_id = get_effective_user_id()
|
user_id = _get_runtime_user_id(runtime)
|
||||||
agent_dir = paths.user_agent_dir(user_id, agent_name)
|
agent_dir = paths.user_agent_dir(user_id, agent_name)
|
||||||
else:
|
else:
|
||||||
# Default agent (no agent_name): SOUL.md lives at the global base dir.
|
# Default agent (no agent_name): SOUL.md lives at the global base dir.
|
||||||
|
|||||||
@@ -324,6 +324,21 @@ def test_context_does_not_override_existing_configurable():
|
|||||||
assert config["configurable"]["subagent_enabled"] is True
|
assert config["configurable"]["subagent_enabled"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_authenticated_user_context_overrides_client_user_id():
|
||||||
|
"""Run context should carry the authenticated user, not client-supplied user_id."""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from app.gateway.services import build_run_config, inject_authenticated_user_context
|
||||||
|
|
||||||
|
config = build_run_config("thread-1", None, None)
|
||||||
|
config["context"] = {"user_id": "spoofed-client"}
|
||||||
|
request = SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(id="auth-user-42")))
|
||||||
|
|
||||||
|
inject_authenticated_user_context(config, request)
|
||||||
|
|
||||||
|
assert config["context"]["user_id"] == "auth-user-42"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from deerflow.tools.builtins.setup_agent_tool import setup_agent
|
from deerflow.tools.builtins.setup_agent_tool import setup_agent
|
||||||
|
|
||||||
# --- Helpers ---
|
# --- Helpers ---
|
||||||
@@ -126,3 +128,23 @@ class TestSetupAgentNoDataLoss:
|
|||||||
assert agent_dir.exists()
|
assert agent_dir.exists()
|
||||||
assert (agent_dir / "SOUL.md").read_text() == "# My Agent"
|
assert (agent_dir / "SOUL.md").read_text() == "# My Agent"
|
||||||
assert (agent_dir / "config.yaml").exists()
|
assert (agent_dir / "config.yaml").exists()
|
||||||
|
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
def test_runtime_user_id_used_when_contextvar_missing(self, tmp_path: Path):
|
||||||
|
"""setup_agent should not fall back to default when runtime carries user_id."""
|
||||||
|
runtime = _DummyRuntime(
|
||||||
|
context={"agent_name": "test-agent", "user_id": "auth-user-42"},
|
||||||
|
tool_call_id="tool-3",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("deerflow.tools.builtins.setup_agent_tool.get_paths", return_value=_make_paths_mock(tmp_path)):
|
||||||
|
setup_agent.func(
|
||||||
|
soul="# My Agent",
|
||||||
|
description="A test agent",
|
||||||
|
runtime=runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_dir = tmp_path / "users" / "auth-user-42" / "agents" / "test-agent"
|
||||||
|
default_dir = tmp_path / "users" / "default" / "agents" / "test-agent"
|
||||||
|
assert (expected_dir / "SOUL.md").read_text() == "# My Agent"
|
||||||
|
assert not default_dir.exists()
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ import {
|
|||||||
AgentNameCheckError,
|
AgentNameCheckError,
|
||||||
AgentsApiDisabledError,
|
AgentsApiDisabledError,
|
||||||
checkAgentName,
|
checkAgentName,
|
||||||
createAgent,
|
|
||||||
getAgent,
|
getAgent,
|
||||||
} from "@/core/agents/api";
|
} from "@/core/agents/api";
|
||||||
import { useI18n } from "@/core/i18n/hooks";
|
import { useI18n } from "@/core/i18n/hooks";
|
||||||
@@ -71,20 +70,6 @@ async function getAgentWithRetry(agentName: string) {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
function getCreateAgentErrorMessage(
|
|
||||||
error: unknown,
|
|
||||||
networkErrorMessage: string,
|
|
||||||
fallbackMessage: string,
|
|
||||||
) {
|
|
||||||
if (error instanceof TypeError && error.message === "Failed to fetch") {
|
|
||||||
return networkErrorMessage;
|
|
||||||
}
|
|
||||||
if (error instanceof Error && error.message) {
|
|
||||||
return error.message;
|
|
||||||
}
|
|
||||||
return fallbackMessage;
|
|
||||||
}
|
|
||||||
|
|
||||||
export default function NewAgentPage() {
|
export default function NewAgentPage() {
|
||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
@@ -93,7 +78,6 @@ export default function NewAgentPage() {
|
|||||||
const [nameInput, setNameInput] = useState("");
|
const [nameInput, setNameInput] = useState("");
|
||||||
const [nameError, setNameError] = useState("");
|
const [nameError, setNameError] = useState("");
|
||||||
const [isCheckingName, setIsCheckingName] = useState(false);
|
const [isCheckingName, setIsCheckingName] = useState(false);
|
||||||
const [isCreatingAgent, setIsCreatingAgent] = useState(false);
|
|
||||||
const [agentName, setAgentName] = useState("");
|
const [agentName, setAgentName] = useState("");
|
||||||
const [agent, setAgent] = useState<Agent | null>(null);
|
const [agent, setAgent] = useState<Agent | null>(null);
|
||||||
const [showSaveHint, setShowSaveHint] = useState(false);
|
const [showSaveHint, setShowSaveHint] = useState(false);
|
||||||
@@ -170,36 +154,16 @@ export default function NewAgentPage() {
|
|||||||
setIsCheckingName(false);
|
setIsCheckingName(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
setIsCreatingAgent(true);
|
|
||||||
try {
|
|
||||||
await createAgent({
|
|
||||||
name: trimmed,
|
|
||||||
description: "",
|
|
||||||
soul: "",
|
|
||||||
});
|
|
||||||
} catch (err) {
|
|
||||||
if (err instanceof AgentsApiDisabledError) {
|
|
||||||
setNameError(t.agents.nameStepApiDisabledError);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setNameError(
|
|
||||||
getCreateAgentErrorMessage(
|
|
||||||
err,
|
|
||||||
t.agents.nameStepNetworkError,
|
|
||||||
t.agents.nameStepCheckError,
|
|
||||||
),
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
} finally {
|
|
||||||
setIsCreatingAgent(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
setAgentName(trimmed);
|
setAgentName(trimmed);
|
||||||
setStep("chat");
|
setStep("chat");
|
||||||
await sendMessage(threadId, {
|
await sendMessage(
|
||||||
text: t.agents.nameStepBootstrapMessage.replace("{name}", trimmed),
|
threadId,
|
||||||
files: [],
|
{
|
||||||
});
|
text: t.agents.nameStepBootstrapMessage.replace("{name}", trimmed),
|
||||||
|
files: [],
|
||||||
|
},
|
||||||
|
{ agent_name: trimmed },
|
||||||
|
);
|
||||||
}, [
|
}, [
|
||||||
nameInput,
|
nameInput,
|
||||||
sendMessage,
|
sendMessage,
|
||||||
@@ -345,9 +309,7 @@ export default function NewAgentPage() {
|
|||||||
<Button
|
<Button
|
||||||
className="w-full"
|
className="w-full"
|
||||||
onClick={() => void handleConfirmName()}
|
onClick={() => void handleConfirmName()}
|
||||||
disabled={
|
disabled={!nameInput.trim() || isCheckingName}
|
||||||
!nameInput.trim() || isCheckingName || isCreatingAgent
|
|
||||||
}
|
|
||||||
>
|
>
|
||||||
{t.agents.nameStepContinue}
|
{t.agents.nameStepContinue}
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ export const enUS: Translations = {
|
|||||||
nameStepApiDisabledError:
|
nameStepApiDisabledError:
|
||||||
"Custom agent management is not enabled on this server. Please contact your administrator.",
|
"Custom agent management is not enabled on this server. Please contact your administrator.",
|
||||||
nameStepBootstrapMessage:
|
nameStepBootstrapMessage:
|
||||||
"The new custom agent name is {name}. Let's bootstrap it's **SOUL**.",
|
"The new custom agent name is {name}. Help me design its purpose, behavior, and SOUL.md before saving it.",
|
||||||
save: "Save agent",
|
save: "Save agent",
|
||||||
saving: "Saving agent...",
|
saving: "Saving agent...",
|
||||||
saveRequested:
|
saveRequested:
|
||||||
|
|||||||
@@ -195,7 +195,7 @@ export const zhCN: Translations = {
|
|||||||
nameStepApiDisabledError:
|
nameStepApiDisabledError:
|
||||||
"服务器未开启自定义智能体管理功能,请联系管理员。",
|
"服务器未开启自定义智能体管理功能,请联系管理员。",
|
||||||
nameStepBootstrapMessage:
|
nameStepBootstrapMessage:
|
||||||
"新智能体的名称是 {name},现在开始为它生成 **SOUL**。",
|
"新智能体的名称是 {name}。请先帮我设计它的用途、行为方式和 SOUL.md,再保存它。",
|
||||||
save: "保存智能体",
|
save: "保存智能体",
|
||||||
saving: "正在保存智能体...",
|
saving: "正在保存智能体...",
|
||||||
saveRequested:
|
saveRequested:
|
||||||
|
|||||||
Reference in New Issue
Block a user