Merge remote-tracking branch 'origin/main' into codex/im-channel-connections

# Conflicts:
#	backend/app/channels/discord.py
#	backend/app/channels/manager.py
#	backend/app/channels/slack.py
#	backend/app/channels/telegram.py
This commit is contained in:
taohe
2026-06-10 21:13:02 +08:00
85 changed files with 5575 additions and 253 deletions
+7
View File
@@ -18,3 +18,10 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
"/help",
}
)
def is_known_channel_command(text: str) -> bool:
"""Return whether text starts with a registered channel control command."""
if not text.startswith("/"):
return False
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
+2 -4
View File
@@ -14,7 +14,7 @@ from typing import Any
import httpx
from app.channels.base import Channel
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -59,9 +59,7 @@ def _normalize_allowed_users(allowed_users: Any) -> set[str]:
def _is_dingtalk_command(text: str) -> bool:
if not text.startswith("/"):
return False
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
return is_known_channel_command(text)
def _extract_text_from_rich_text(rich_text_list: list) -> str:
+3 -2
View File
@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -301,7 +302,7 @@ class DiscordChannel(Channel):
# If this is a known active thread, process normally
if thread_id in self._active_thread_ids:
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
msg_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=chat_id,
user_id=str(message.author.id),
@@ -409,7 +410,7 @@ class DiscordChannel(Channel):
chat_id = channel_id
typing_target = message.channel # Type into the channel
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
msg_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=chat_id,
user_id=str(message.author.id),
+2 -4
View File
@@ -11,7 +11,7 @@ import time
from typing import Any, Literal
from app.channels.base import Channel
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import (
PENDING_CLARIFICATION_METADATA_KEY,
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
@@ -30,9 +30,7 @@ PENDING_CLARIFICATION_TTL_SECONDS = 30 * 60
def _is_feishu_command(text: str) -> bool:
if not text.startswith("/"):
return False
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
return is_known_channel_command(text)
class FeishuChannel(Channel):
+127 -13
View File
@@ -8,6 +8,7 @@ import mimetypes
import re
import time
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@@ -26,8 +27,13 @@ from app.channels.message_bus import (
from app.channels.store import ChannelStore
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
from app.gateway.internal_auth import create_internal_auth_headers
from deerflow.config.agents_config import load_agent_config
from deerflow.config.paths import make_safe_user_id
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.skills.slash import parse_slash_skill_reference
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.storage.skill_storage import SkillStorage
from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY
logger = logging.getLogger(__name__)
@@ -124,6 +130,16 @@ class InvalidChannelSessionConfigError(ValueError):
"""Raised when IM channel session overrides contain invalid agent config."""
class SlashSkillCommandResolutionError(RuntimeError):
"""Raised when IM slash-skill command resolution cannot complete safely."""
@dataclass(frozen=True, slots=True)
class _SlashSkillCommandResolution:
route_to_chat: bool = False
failure_message: str | None = None
def _is_thread_busy_error(exc: BaseException | None) -> bool:
if exc is None:
return False
@@ -410,6 +426,46 @@ def _format_artifact_text(artifacts: list[str]) -> str:
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
def _unknown_command_reply(command: str | None = None) -> str:
available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS))
if command:
return f"Unknown command: /{command}. Available commands: {available}"
return f"Unknown command. Available commands: {available}"
def _human_input_message(content: str, *, original_content: str | None = None) -> dict[str, Any]:
message: dict[str, Any] = {"role": "human", "content": content}
if original_content is not None and original_content != content:
message["additional_kwargs"] = {ORIGINAL_USER_CONTENT_KEY: original_content}
return message
def _resolve_slash_skill_command(
text: str,
available_skills: set[str] | None = None,
storage: SkillStorage | Callable[[], SkillStorage] | None = None,
) -> _SlashSkillCommandResolution | None:
reference = parse_slash_skill_reference(text)
if reference is None:
return None
try:
resolved_storage = storage() if callable(storage) else storage or get_or_new_skill_storage()
skills = resolved_storage.load_skills(enabled_only=False)
skill = next((candidate for candidate in skills if candidate.name == reference.name), None)
if skill is None:
return None
if not skill.enabled:
return _SlashSkillCommandResolution(failure_message=f"Skill `/{reference.name}` is installed but disabled. Enable it before using slash activation.")
if available_skills is not None and reference.name not in available_skills:
return _SlashSkillCommandResolution(failure_message=f"Skill `/{reference.name}` is not available for this agent.")
return _SlashSkillCommandResolution(route_to_chat=True)
except Exception as exc:
logger.exception("[Manager] failed to resolve slash skill command")
raise SlashSkillCommandResolutionError("Failed to resolve slash skill command. Please check the skill configuration.") from exc
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
"""Resolve virtual artifact paths to host filesystem paths with metadata.
@@ -626,6 +682,7 @@ class ChannelManager:
self._channel_sessions = dict(channel_sessions or {})
self._connection_repo = connection_repo
self._client = None # lazy init — langgraph_sdk async client
self._skill_storage: SkillStorage | None = None
self._csrf_token = generate_csrf_token()
self._semaphore: asyncio.Semaphore | None = None
self._running = False
@@ -702,6 +759,21 @@ class ChannelManager:
return assistant_id, run_config, run_context
def _resolve_available_skill_names(self, msg: InboundMessage) -> set[str] | None:
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or ""
_, _, run_context = self._resolve_run_params(msg, thread_id)
if run_context.get("is_bootstrap"):
return {"bootstrap"}
agent_name = run_context.get("agent_name")
if not isinstance(agent_name, str) or not agent_name.strip():
return None
agent_config = load_agent_config(_normalize_custom_agent_name(agent_name))
if agent_config and agent_config.skills is not None:
return set(agent_config.skills)
return None
# -- LangGraph SDK client (lazy) ----------------------------------------
def _get_client(self):
@@ -719,6 +791,11 @@ class ChannelManager:
)
return self._client
def _get_skill_storage(self) -> SkillStorage:
if self._skill_storage is None:
self._skill_storage = get_or_new_skill_storage()
return self._skill_storage
# -- lifecycle ---------------------------------------------------------
async def start(self) -> None:
@@ -788,6 +865,14 @@ class ChannelManager:
exc,
)
await self._send_error(msg, str(exc))
except SlashSkillCommandResolutionError as exc:
logger.warning(
"Slash skill command resolution failed for %s (chat=%s): %s",
msg.channel_name,
msg.chat_id,
exc,
)
await self._send_error(msg, str(exc))
except Exception:
logger.exception(
"Error handling message from %s (chat=%s)",
@@ -865,9 +950,11 @@ class ChannelManager:
if extra_context:
run_context.update(extra_context)
original_text = msg.text
uploaded = await _ingest_inbound_files(thread_id, msg)
if uploaded:
msg.text = f"{_format_uploaded_files_block(uploaded)}\n\n{msg.text}".strip()
human_message = _human_input_message(msg.text, original_content=original_text)
if self._channel_supports_streaming(msg.channel_name):
await self._handle_streaming_chat(
@@ -877,6 +964,7 @@ class ChannelManager:
assistant_id,
run_config,
run_context,
human_message,
)
return
@@ -885,7 +973,7 @@ class ChannelManager:
result = await client.runs.wait(
thread_id,
assistant_id,
input={"messages": [{"role": "human", "content": msg.text}]},
input={"messages": [human_message]},
config=run_config,
context=run_context,
multitask_strategy="reject",
@@ -940,6 +1028,7 @@ class ChannelManager:
assistant_id: str,
run_config: dict[str, Any],
run_context: dict[str, Any],
human_message: dict[str, Any],
) -> None:
logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100])
@@ -955,7 +1044,7 @@ class ChannelManager:
async for chunk in client.runs.stream(
thread_id,
assistant_id,
input={"messages": [{"role": "human", "content": msg.text}]},
input={"messages": [human_message]},
config=run_config,
context=run_context,
stream_mode=["messages-tuple", "values"],
@@ -1046,11 +1135,20 @@ class ChannelManager:
# -- command handling --------------------------------------------------
async def _handle_command(self, msg: InboundMessage) -> None:
text = msg.text.strip()
raw_text = msg.text
text = raw_text.strip()
parts = text.split(maxsplit=1)
command = parts[0].lower().lstrip("/")
reply: str | None = None
if not parts:
command = None
reply = _unknown_command_reply()
else:
command = parts[0].lower().removeprefix("/")
if command == "bootstrap":
if reply is None and not raw_text.startswith("/"):
reply = _unknown_command_reply(command)
if reply is None and command == "bootstrap":
from dataclasses import replace as _dc_replace
chat_text = parts[1] if len(parts) > 1 else "Initialize workspace"
@@ -1058,21 +1156,21 @@ class ChannelManager:
await self._handle_chat(chat_msg, extra_context={"is_bootstrap": True})
return
if command == "new":
if reply is None and command == "new":
# Create a new thread through Gateway
client = self._get_client()
thread = await client.threads.create()
new_thread_id = thread["thread_id"]
await self._store_thread_id(msg, new_thread_id)
reply = "New conversation started."
elif command == "status":
elif reply is None and command == "status":
thread_id = await self._lookup_thread_id(msg)
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
elif command == "models":
elif reply is None and command == "models":
reply = await self._fetch_gateway("/api/models", "models")
elif command == "memory":
elif reply is None and command == "memory":
reply = await self._fetch_gateway("/api/memory", "memory")
elif command == "help":
elif reply is None and command == "help":
reply = (
"Available commands:\n"
"/bootstrap — Start a bootstrap session (enables agent setup)\n"
@@ -1080,11 +1178,27 @@ class ChannelManager:
"/status — Show current thread info\n"
"/models — List available models\n"
"/memory — Show memory status\n"
"/<skill-name> <task> — Activate an enabled skill for one turn\n"
"/help — Show this help"
)
else:
available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS))
reply = f"Unknown command: /{command}. Available commands: {available}"
elif reply is None:
slash_resolution = await asyncio.to_thread(
lambda: _resolve_slash_skill_command(
raw_text,
self._resolve_available_skill_names(msg),
self._get_skill_storage,
)
)
if slash_resolution and slash_resolution.failure_message:
reply = slash_resolution.failure_message
elif slash_resolution and slash_resolution.route_to_chat:
from dataclasses import replace as _dc_replace
chat_msg = _dc_replace(msg, msg_type=InboundMessageType.CHAT)
await self._handle_chat(chat_msg)
return
else:
reply = _unknown_command_reply(command)
outbound = OutboundMessage(
channel_name=msg.channel_name,
+37 -1
View File
@@ -9,6 +9,7 @@ from typing import Any
from markdown_to_mrkdwn import SlackMarkdownConverter
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -32,6 +33,20 @@ def _normalize_allowed_users(allowed_users: Any) -> set[str]:
return {str(user_id) for user_id in values if str(user_id)}
def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str:
if not bot_user_id:
return text
if not text.startswith("<@"):
return text
end = text.find(">")
if end <= 2:
return text
mentioned_user_id = text[2:end].split("|", 1)[0].lstrip("!")
if mentioned_user_id != bot_user_id:
return text
return text[end + 1 :].lstrip()
class SlackChannel(Channel):
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
@@ -51,6 +66,8 @@ class SlackChannel(Channel):
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
self._connection_repo = config.get("connection_repo")
self._web_client_factory = config.get("web_client_factory")
configured_bot_user_id = config.get("bot_user_id")
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
async def start(self) -> None:
if self._running:
@@ -83,6 +100,17 @@ class SlackChannel(Channel):
return
self._web_client = self._web_client_factory(token=bot_token)
if self._bot_user_id is None:
try:
auth_info = await asyncio.to_thread(self._web_client.auth_test)
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
if user_id is None:
auth_get = getattr(auth_info, "get", None)
user_id = auth_get("user_id") if callable(auth_get) else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
except Exception:
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
self._socket_client = SocketModeClient(
app_token=app_token,
web_client=self._web_client,
@@ -243,6 +271,12 @@ class SlackChannel(Channel):
if event_type != "events_api":
return
if self._bot_user_id is None:
authorization = next((item for item in req.payload.get("authorizations", []) if isinstance(item, dict)), None)
user_id = authorization.get("user_id") if authorization else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
event = req.payload.get("event", {})
etype = event.get("type", "")
@@ -266,13 +300,15 @@ class SlackChannel(Channel):
return
text = event.get("text", "").strip()
if event.get("type") == "app_mention":
text = _strip_leading_slack_bot_mention(text, self._bot_user_id)
if not text:
return
channel_id = event.get("channel", "")
thread_ts = event.get("thread_ts") or event.get("ts", "")
if text.startswith("/"):
if is_known_channel_command(text):
msg_type = InboundMessageType.COMMAND
else:
msg_type = InboundMessageType.CHAT
+34 -2
View File
@@ -61,12 +61,17 @@ class TelegramChannel(Channel):
# Command handlers
app.add_handler(CommandHandler("start", self._cmd_start))
app.add_handler(CommandHandler("bootstrap", self._cmd_generic))
app.add_handler(CommandHandler("new", self._cmd_generic))
app.add_handler(CommandHandler("status", self._cmd_generic))
app.add_handler(CommandHandler("models", self._cmd_generic))
app.add_handler(CommandHandler("memory", self._cmd_generic))
app.add_handler(CommandHandler("help", self._cmd_generic))
# Slash skill commands are dynamic and cannot all be pre-registered
# with Telegram, so route unknown slash commands through chat handling.
app.add_handler(MessageHandler(filters.TEXT & filters.COMMAND, self._on_text))
# General message handler
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self._on_text))
@@ -306,6 +311,33 @@ class TelegramChannel(Channel):
inbound.workspace_id = connection.get("workspace_id")
return inbound
def _get_bot_username(self, context) -> str | None:
bot = getattr(context, "bot", None)
username = getattr(bot, "username", None)
if not username and self._application is not None:
username = getattr(getattr(self._application, "bot", None), "username", None)
return str(username) if username else None
@staticmethod
def _strip_bot_username_from_leading_command(text: str, bot_username: str | None) -> str:
username = (bot_username or "").lstrip("@").lower()
if not username or not text.startswith("/"):
return text
parts = text.split(maxsplit=1)
command_token = parts[0]
if "@" not in command_token:
return text
command_name, addressed_username = command_token[1:].rsplit("@", 1)
if not command_name or addressed_username.lower() != username:
return text
normalized = f"/{command_name}"
if len(parts) > 1:
normalized = f"{normalized} {parts[1]}"
return normalized
async def _cmd_start(self, update, context) -> None:
"""Handle /start command."""
if not self._check_user(update.effective_user.id):
@@ -326,7 +358,7 @@ class TelegramChannel(Channel):
if not self._check_user(update.effective_user.id):
return
text = update.message.text
text = self._strip_bot_username_from_leading_command(update.message.text.strip(), self._get_bot_username(context))
chat_id = str(update.effective_chat.id)
user_id = str(update.effective_user.id)
msg_id = str(update.message.message_id)
@@ -363,7 +395,7 @@ class TelegramChannel(Channel):
if not self._check_user(update.effective_user.id):
return
text = update.message.text.strip()
text = self._strip_bot_username_from_leading_command(update.message.text.strip(), self._get_bot_username(context))
if not text:
return
+2 -1
View File
@@ -22,6 +22,7 @@ from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -620,7 +621,7 @@ class WechatChannel(Channel):
chat_id=chat_id,
user_id=chat_id,
text=text,
msg_type=InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT,
msg_type=InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT,
thread_ts=thread_ts,
files=files,
metadata={
+2 -1
View File
@@ -8,6 +8,7 @@ from collections.abc import Awaitable, Callable
from typing import Any, cast
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import (
InboundMessageType,
MessageBus,
@@ -270,7 +271,7 @@ class WeComChannel(Channel):
user_id = (body.get("from") or {}).get("userid")
inbound_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=user_id, # keep user's conversation in memory
user_id=user_id,
+2
View File
@@ -6,6 +6,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
@@ -173,6 +174,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
startup_config = get_app_config()
apply_logging_level(startup_config.log_level)
logger.info("Configuration loaded successfully")
warn_if_auth_disabled_enabled()
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"
logger.exception(error_msg)
+54
View File
@@ -0,0 +1,54 @@
"""Shared helpers for local/E2E auth-disabled mode."""
from __future__ import annotations
import logging
import os
from types import SimpleNamespace
AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
AUTH_DISABLED_USER_ID = "e2e-user"
AUTH_DISABLED_USER_EMAIL = "e2e@test.local"
AUTH_SOURCE_SESSION = "session"
AUTH_SOURCE_INTERNAL = "internal"
AUTH_SOURCE_AUTH_DISABLED = "auth_disabled"
_PRODUCTION_ENV_VARS: tuple[str, ...] = ("DEER_FLOW_ENV", "ENVIRONMENT")
_PRODUCTION_ENV_VALUES: frozenset[str] = frozenset({"prod", "production"})
logger = logging.getLogger(__name__)
def is_explicit_production_environment() -> bool:
return any(os.environ.get(name, "").strip().lower() in _PRODUCTION_ENV_VALUES for name in _PRODUCTION_ENV_VARS)
def is_auth_disabled_requested() -> bool:
return os.environ.get(AUTH_DISABLED_ENV_VAR) == "1"
def is_auth_disabled() -> bool:
return is_auth_disabled_requested() and not is_explicit_production_environment()
def warn_if_auth_disabled_enabled() -> None:
if not is_auth_disabled():
return
logger.warning(
"%s=1 is active: authentication is bypassed and anonymous requests run as synthetic admin user %r. Do not enable this in shared or production deployments.",
AUTH_DISABLED_ENV_VAR,
AUTH_DISABLED_USER_ID,
)
def get_auth_disabled_user():
return SimpleNamespace(
id=AUTH_DISABLED_USER_ID,
email=AUTH_DISABLED_USER_EMAIL,
password_hash=None,
system_role="admin",
needs_setup=False,
token_version=0,
)
+39 -22
View File
@@ -17,6 +17,13 @@ from starlette.responses import JSONResponse
from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.auth_disabled import (
AUTH_SOURCE_AUTH_DISABLED,
AUTH_SOURCE_INTERNAL,
AUTH_SOURCE_SESSION,
get_auth_disabled_user,
is_auth_disabled,
)
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
from deerflow.runtime.user_context import reset_current_user, set_current_user
@@ -83,8 +90,38 @@ class AuthMiddleware(BaseHTTPMiddleware):
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
internal_user = get_internal_user()
auth_source = AUTH_SOURCE_SESSION
access_token = request.cookies.get("access_token")
# Non-public path: require session cookie
if internal_user is None and not request.cookies.get("access_token"):
if internal_user is not None:
user = internal_user
auth_source = AUTH_SOURCE_INTERNAL
elif access_token:
# Strict JWT validation: reject junk/expired tokens with 401
# right here instead of silently passing through. This closes
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
# without this, non-isolation routes like /api/models would
# accept any cookie-shaped string as authentication.
#
# We call the *strict* resolver so that fine-grained error
# codes (token_expired, token_invalid, user_not_found, …)
# propagate from AuthErrorCode, not get flattened into one
# generic code. BaseHTTPMiddleware doesn't let HTTPException
# bubble up, so we catch and render it as JSONResponse here.
from app.gateway.deps import get_current_user_from_request
try:
user = await get_current_user_from_request(request)
except HTTPException as exc:
if not is_auth_disabled():
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
user = get_auth_disabled_user()
auth_source = AUTH_SOURCE_AUTH_DISABLED
elif is_auth_disabled():
user = get_auth_disabled_user()
auth_source = AUTH_SOURCE_AUTH_DISABLED
else:
return JSONResponse(
status_code=401,
content={
@@ -95,32 +132,12 @@ class AuthMiddleware(BaseHTTPMiddleware):
},
)
# Strict JWT validation: reject junk/expired tokens with 401
# right here instead of silently passing through. This closes
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
# without this, non-isolation routes like /api/models would
# accept any cookie-shaped string as authentication.
#
# We call the *strict* resolver so that fine-grained error
# codes (token_expired, token_invalid, user_not_found, …)
# propagate from AuthErrorCode, not get flattened into one
# generic code. BaseHTTPMiddleware doesn't let HTTPException
# bubble up, so we catch and render it as JSONResponse here.
from app.gateway.deps import get_current_user_from_request
if internal_user is not None:
user = internal_user
else:
try:
user = await get_current_user_from_request(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
# Stamp both request.state.user (for the contextvar pattern)
# and request.state.auth (so @require_permission's "auth is
# None" branch short-circuits instead of running the entire
# JWT-decode + DB-lookup pipeline a second time per request).
request.state.user = user
request.state.auth_source = auth_source
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
token = set_current_user(user)
try:
+5
View File
@@ -14,6 +14,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
from app.gateway.auth_disabled import is_auth_disabled
CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_TOKEN_LENGTH = 64 # bytes
@@ -38,6 +40,9 @@ def should_check_csrf(request: Request) -> bool:
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
return False
if is_auth_disabled():
return False
path = request.url.path.rstrip("/")
if path.startswith("/api/channels/webhooks/"):
return False
+11
View File
@@ -331,6 +331,17 @@ async def get_current_user_from_request(request: Request):
Raises HTTPException 401 if not authenticated.
"""
state = getattr(request, "state", None)
state_user = getattr(state, "user", None)
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED, AUTH_SOURCE_INTERNAL, AUTH_SOURCE_SESSION
if state_user is not None and getattr(state, "auth_source", None) in {
AUTH_SOURCE_SESSION,
AUTH_SOURCE_AUTH_DISABLED,
AUTH_SOURCE_INTERNAL,
}:
return state_user
from app.gateway.auth import decode_token
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
+7
View File
@@ -20,6 +20,7 @@ from langgraph_sdk import Auth
from app.gateway.auth.errors import TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
from app.gateway.deps import get_local_provider
auth = Auth()
@@ -38,6 +39,9 @@ def _check_csrf(request) -> None:
if method.upper() not in _CSRF_METHODS:
return
if is_auth_disabled():
return
cookie_token = request.cookies.get("csrf_token")
header_token = request.headers.get("x-csrf-token")
@@ -66,6 +70,9 @@ async def authenticate(request):
# are rejected early, even if the cookie carries a valid JWT.
_check_csrf(request)
if is_auth_disabled():
return AUTH_DISABLED_USER_ID
token = request.cookies.get("access_token")
if not token:
raise Auth.exceptions.HTTPException(
+70 -45
View File
@@ -1,5 +1,6 @@
"""CRUD API for custom agents."""
import asyncio
import logging
import re
import shutil
@@ -213,48 +214,61 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
user_id = get_effective_user_id()
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, normalized_name)
legacy_dir = paths.agent_dir(normalized_name)
def _create_agent() -> AgentResponse | None:
# Worker thread: base-dir resolution, existence checks, directory/file
# creation, read-back, and failure cleanup are all blocking filesystem
# IO that must stay off the event loop.
agent_dir = paths.user_agent_dir(user_id, normalized_name)
legacy_dir = paths.agent_dir(normalized_name)
if agent_dir.exists() or legacy_dir.exists():
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
if legacy_dir.exists():
return None # signals 409 to the caller
try:
try:
agent_dir.mkdir(parents=True, exist_ok=False)
except FileExistsError:
return None # signals 409 to the caller
# Write config.yaml
config_data: dict = {"name": normalized_name}
if request.description:
config_data["description"] = request.description
if request.model is not None:
config_data["model"] = request.model
if request.tool_groups is not None:
config_data["tool_groups"] = request.tool_groups
if request.skills is not None:
config_data["skills"] = request.skills
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True)
# Write SOUL.md
soul_file = agent_dir / "SOUL.md"
soul_file.write_text(request.soul, encoding="utf-8")
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
except Exception:
# Clean up partial state on failure before surfacing the error.
if agent_dir.exists():
shutil.rmtree(agent_dir)
raise
try:
agent_dir.mkdir(parents=True, exist_ok=True)
# Write config.yaml
config_data: dict = {"name": normalized_name}
if request.description:
config_data["description"] = request.description
if request.model is not None:
config_data["model"] = request.model
if request.tool_groups is not None:
config_data["tool_groups"] = request.tool_groups
if request.skills is not None:
config_data["skills"] = request.skills
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True)
# Write SOUL.md
soul_file = agent_dir / "SOUL.md"
soul_file.write_text(request.soul, encoding="utf-8")
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
except HTTPException:
raise
response = await asyncio.to_thread(_create_agent)
except Exception as e:
# Clean up on failure
if agent_dir.exists():
shutil.rmtree(agent_dir)
logger.error(f"Failed to create agent '{request.name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}")
if response is None:
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
return response
@router.put(
"/agents/{name}",
@@ -428,19 +442,30 @@ async def delete_agent(name: str) -> None:
name = _normalize_agent_name(name)
user_id = get_effective_user_id()
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, name)
if not agent_dir.exists():
if paths.agent_dir(name).exists():
raise HTTPException(
status_code=409,
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
)
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
def _remove_agent_dir() -> tuple[str, str]:
# Runs in a worker thread: resolving the base dir, probing the directory
# (`exists`), and removing it (`rmtree`) are all blocking filesystem IO
# that must stay off the event loop.
agent_dir = paths.user_agent_dir(user_id, name)
if not agent_dir.exists():
outcome = "legacy" if paths.agent_dir(name).exists() else "missing"
return outcome, str(agent_dir)
shutil.rmtree(agent_dir)
return "deleted", str(agent_dir)
try:
shutil.rmtree(agent_dir)
logger.info(f"Deleted agent '{name}' from {agent_dir}")
outcome, agent_dir = await asyncio.to_thread(_remove_agent_dir)
except Exception as e:
logger.error(f"Failed to delete agent '{name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to delete agent: {str(e)}")
if outcome == "legacy":
raise HTTPException(
status_code=409,
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
)
if outcome == "missing":
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
logger.info(f"Deleted agent '{name}' from {agent_dir}")
+10
View File
@@ -341,9 +341,19 @@ async def change_password(request: Request, response: Response, body: ChangePass
- Re-issues session cookie with new token_version
"""
from app.gateway.auth.password import hash_password_async, verify_password_async
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED
user = await get_current_user_from_request(request)
if getattr(request.state, "auth_source", None) == AUTH_SOURCE_AUTH_DISABLED:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=AuthErrorResponse(
code=AuthErrorCode.INVALID_CREDENTIALS,
message="Password changes are not available when DEER_FLOW_AUTH_DISABLED=1.",
).model_dump(),
)
if user.password_hash is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())