mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 18:05:58 +00:00
Route no-auth channel sessions to local user
This commit is contained in:
@@ -440,10 +440,31 @@ def _human_input_message(content: str, *, original_content: str | None = None) -
|
||||
return message
|
||||
|
||||
|
||||
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
|
||||
if not msg.owner_user_id:
|
||||
def _auth_disabled_owner_user_id() -> str | None:
|
||||
try:
|
||||
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
||||
except Exception:
|
||||
logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True)
|
||||
return None
|
||||
return create_internal_auth_headers(owner_user_id=msg.owner_user_id)
|
||||
return AUTH_DISABLED_USER_ID if is_auth_disabled() else None
|
||||
|
||||
|
||||
def _effective_owner_user_id(msg: InboundMessage) -> str | None:
|
||||
return _auth_disabled_owner_user_id() or msg.owner_user_id
|
||||
|
||||
|
||||
def _apply_effective_owner(msg: InboundMessage) -> InboundMessage:
|
||||
owner_user_id = _effective_owner_user_id(msg)
|
||||
if owner_user_id:
|
||||
msg.owner_user_id = owner_user_id
|
||||
return msg
|
||||
|
||||
|
||||
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
|
||||
owner_user_id = _effective_owner_user_id(msg)
|
||||
if not owner_user_id:
|
||||
return None
|
||||
return create_internal_auth_headers(owner_user_id=owner_user_id)
|
||||
|
||||
|
||||
def _resolve_slash_skill_command(
|
||||
@@ -741,8 +762,9 @@ class ChannelManager:
|
||||
# owns the connection. Preserve the raw platform user under
|
||||
# ``channel_user_id`` for platform-facing lookups and audits.
|
||||
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
||||
if msg.owner_user_id:
|
||||
run_context_identity["user_id"] = make_safe_user_id(msg.owner_user_id)
|
||||
owner_user_id = _effective_owner_user_id(msg)
|
||||
if owner_user_id:
|
||||
run_context_identity["user_id"] = make_safe_user_id(owner_user_id)
|
||||
elif msg.user_id:
|
||||
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
||||
if msg.user_id:
|
||||
@@ -857,6 +879,7 @@ class ChannelManager:
|
||||
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
|
||||
|
||||
async def _handle_message(self, msg: InboundMessage) -> None:
|
||||
msg = _apply_effective_owner(msg)
|
||||
async with self._semaphore:
|
||||
try:
|
||||
if msg.msg_type == InboundMessageType.COMMAND:
|
||||
|
||||
@@ -894,10 +894,12 @@ class TestChannelManager:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_clarification_follow_up_preserves_history(self):
|
||||
def test_clarification_follow_up_preserves_history(self, monkeypatch):
|
||||
"""Conversation should continue after ask_clarification instead of resetting history."""
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
@@ -2336,8 +2338,9 @@ class TestResolveRunParamsUserId:
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
return ChannelManager(bus=bus, store=store)
|
||||
|
||||
def test_safe_user_id_is_passed_through(self):
|
||||
def test_safe_user_id_is_passed_through(self, monkeypatch):
|
||||
manager = self._manager()
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
@@ -2345,8 +2348,9 @@ class TestResolveRunParamsUserId:
|
||||
assert run_context["user_id"] == "123456"
|
||||
assert run_context["channel_user_id"] == "123456"
|
||||
|
||||
def test_connection_owner_user_id_takes_precedence_over_platform_user_id(self):
|
||||
def test_connection_owner_user_id_takes_precedence_over_platform_user_id(self, monkeypatch):
|
||||
manager = self._manager()
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
msg = InboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id="C123",
|
||||
@@ -2361,10 +2365,61 @@ class TestResolveRunParamsUserId:
|
||||
assert run_context["user_id"] == "deerflow-user-1"
|
||||
assert run_context["channel_user_id"] == "U-platform"
|
||||
|
||||
def test_unsafe_user_id_is_normalized_but_raw_preserved(self):
|
||||
def test_auth_disabled_user_id_is_used_for_unbound_channel_messages(self, monkeypatch):
|
||||
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||
|
||||
manager = self._manager()
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
|
||||
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
|
||||
assert run_context["channel_user_id"] == "U-platform"
|
||||
|
||||
from app.channels.manager import _owner_headers
|
||||
|
||||
headers = _owner_headers(msg)
|
||||
assert headers is not None
|
||||
assert headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == AUTH_DISABLED_USER_ID
|
||||
|
||||
def test_auth_disabled_user_id_overrides_bound_owner_for_local_visibility(self, monkeypatch):
|
||||
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
|
||||
|
||||
manager = self._manager()
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
msg = InboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id="C123",
|
||||
user_id="U-platform",
|
||||
owner_user_id="real-user-from-old-binding",
|
||||
text="hi",
|
||||
)
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
|
||||
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
|
||||
assert run_context["channel_user_id"] == "U-platform"
|
||||
|
||||
def test_unbound_channel_messages_keep_platform_user_id_when_auth_is_enabled(self, monkeypatch):
|
||||
from app.channels.manager import _owner_headers
|
||||
|
||||
manager = self._manager()
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
|
||||
assert run_context["user_id"] == "U-platform"
|
||||
assert run_context["channel_user_id"] == "U-platform"
|
||||
assert _owner_headers(msg) is None
|
||||
|
||||
def test_unsafe_user_id_is_normalized_but_raw_preserved(self, monkeypatch):
|
||||
from deerflow.config.paths import make_safe_user_id
|
||||
|
||||
manager = self._manager()
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
raw = "user@example.com"
|
||||
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
|
||||
|
||||
@@ -2375,8 +2430,9 @@ class TestResolveRunParamsUserId:
|
||||
assert run_context["channel_user_id"] == raw
|
||||
|
||||
@pytest.mark.parametrize("raw_user_id", ["", None])
|
||||
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id):
|
||||
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id, monkeypatch):
|
||||
manager = self._manager()
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
@@ -2386,11 +2442,13 @@ class TestResolveRunParamsUserId:
|
||||
|
||||
|
||||
class TestChannelManagerConnectionRouting:
|
||||
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path):
|
||||
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path, monkeypatch):
|
||||
from app.channels.manager import ChannelManager
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||
from deerflow.persistence.engine import close_engine
|
||||
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
|
||||
async def go():
|
||||
repo = await _make_channel_connection_repo(tmp_path)
|
||||
alice = await repo.upsert_connection(
|
||||
|
||||
Reference in New Issue
Block a user