mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Route no-auth channel sessions to local user
This commit is contained in:
@@ -440,10 +440,31 @@ def _human_input_message(content: str, *, original_content: str | None = None) -
|
|||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
|
def _auth_disabled_owner_user_id() -> str | None:
|
||||||
if not msg.owner_user_id:
|
try:
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True)
|
||||||
return None
|
return None
|
||||||
return create_internal_auth_headers(owner_user_id=msg.owner_user_id)
|
return AUTH_DISABLED_USER_ID if is_auth_disabled() else None
|
||||||
|
|
||||||
|
|
||||||
|
def _effective_owner_user_id(msg: InboundMessage) -> str | None:
|
||||||
|
return _auth_disabled_owner_user_id() or msg.owner_user_id
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_effective_owner(msg: InboundMessage) -> InboundMessage:
|
||||||
|
owner_user_id = _effective_owner_user_id(msg)
|
||||||
|
if owner_user_id:
|
||||||
|
msg.owner_user_id = owner_user_id
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
|
||||||
|
owner_user_id = _effective_owner_user_id(msg)
|
||||||
|
if not owner_user_id:
|
||||||
|
return None
|
||||||
|
return create_internal_auth_headers(owner_user_id=owner_user_id)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_slash_skill_command(
|
def _resolve_slash_skill_command(
|
||||||
@@ -741,8 +762,9 @@ class ChannelManager:
|
|||||||
# owns the connection. Preserve the raw platform user under
|
# owns the connection. Preserve the raw platform user under
|
||||||
# ``channel_user_id`` for platform-facing lookups and audits.
|
# ``channel_user_id`` for platform-facing lookups and audits.
|
||||||
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
||||||
if msg.owner_user_id:
|
owner_user_id = _effective_owner_user_id(msg)
|
||||||
run_context_identity["user_id"] = make_safe_user_id(msg.owner_user_id)
|
if owner_user_id:
|
||||||
|
run_context_identity["user_id"] = make_safe_user_id(owner_user_id)
|
||||||
elif msg.user_id:
|
elif msg.user_id:
|
||||||
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
||||||
if msg.user_id:
|
if msg.user_id:
|
||||||
@@ -857,6 +879,7 @@ class ChannelManager:
|
|||||||
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
|
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
|
||||||
|
|
||||||
async def _handle_message(self, msg: InboundMessage) -> None:
|
async def _handle_message(self, msg: InboundMessage) -> None:
|
||||||
|
msg = _apply_effective_owner(msg)
|
||||||
async with self._semaphore:
|
async with self._semaphore:
|
||||||
try:
|
try:
|
||||||
if msg.msg_type == InboundMessageType.COMMAND:
|
if msg.msg_type == InboundMessageType.COMMAND:
|
||||||
|
|||||||
@@ -894,10 +894,12 @@ class TestChannelManager:
|
|||||||
|
|
||||||
_run(go())
|
_run(go())
|
||||||
|
|
||||||
def test_clarification_follow_up_preserves_history(self):
|
def test_clarification_follow_up_preserves_history(self, monkeypatch):
|
||||||
"""Conversation should continue after ask_clarification instead of resetting history."""
|
"""Conversation should continue after ask_clarification instead of resetting history."""
|
||||||
from app.channels.manager import ChannelManager
|
from app.channels.manager import ChannelManager
|
||||||
|
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
|
|
||||||
async def go():
|
async def go():
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||||
@@ -2336,8 +2338,9 @@ class TestResolveRunParamsUserId:
|
|||||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||||
return ChannelManager(bus=bus, store=store)
|
return ChannelManager(bus=bus, store=store)
|
||||||
|
|
||||||
def test_safe_user_id_is_passed_through(self):
|
def test_safe_user_id_is_passed_through(self, monkeypatch):
|
||||||
manager = self._manager()
|
manager = self._manager()
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")
|
msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")
|
||||||
|
|
||||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
@@ -2345,8 +2348,9 @@ class TestResolveRunParamsUserId:
|
|||||||
assert run_context["user_id"] == "123456"
|
assert run_context["user_id"] == "123456"
|
||||||
assert run_context["channel_user_id"] == "123456"
|
assert run_context["channel_user_id"] == "123456"
|
||||||
|
|
||||||
def test_connection_owner_user_id_takes_precedence_over_platform_user_id(self):
|
def test_connection_owner_user_id_takes_precedence_over_platform_user_id(self, monkeypatch):
|
||||||
manager = self._manager()
|
manager = self._manager()
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
msg = InboundMessage(
|
msg = InboundMessage(
|
||||||
channel_name="slack",
|
channel_name="slack",
|
||||||
chat_id="C123",
|
chat_id="C123",
|
||||||
@@ -2361,10 +2365,61 @@ class TestResolveRunParamsUserId:
|
|||||||
assert run_context["user_id"] == "deerflow-user-1"
|
assert run_context["user_id"] == "deerflow-user-1"
|
||||||
assert run_context["channel_user_id"] == "U-platform"
|
assert run_context["channel_user_id"] == "U-platform"
|
||||||
|
|
||||||
def test_unsafe_user_id_is_normalized_but_raw_preserved(self):
|
def test_auth_disabled_user_id_is_used_for_unbound_channel_messages(self, monkeypatch):
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
|
||||||
|
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||||
|
|
||||||
|
manager = self._manager()
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
|
||||||
|
|
||||||
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
|
|
||||||
|
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
|
||||||
|
assert run_context["channel_user_id"] == "U-platform"
|
||||||
|
|
||||||
|
from app.channels.manager import _owner_headers
|
||||||
|
|
||||||
|
headers = _owner_headers(msg)
|
||||||
|
assert headers is not None
|
||||||
|
assert headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == AUTH_DISABLED_USER_ID
|
||||||
|
|
||||||
|
def test_auth_disabled_user_id_overrides_bound_owner_for_local_visibility(self, monkeypatch):
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
|
||||||
|
|
||||||
|
manager = self._manager()
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel_name="slack",
|
||||||
|
chat_id="C123",
|
||||||
|
user_id="U-platform",
|
||||||
|
owner_user_id="real-user-from-old-binding",
|
||||||
|
text="hi",
|
||||||
|
)
|
||||||
|
|
||||||
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
|
|
||||||
|
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
|
||||||
|
assert run_context["channel_user_id"] == "U-platform"
|
||||||
|
|
||||||
|
def test_unbound_channel_messages_keep_platform_user_id_when_auth_is_enabled(self, monkeypatch):
|
||||||
|
from app.channels.manager import _owner_headers
|
||||||
|
|
||||||
|
manager = self._manager()
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
|
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
|
||||||
|
|
||||||
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
|
|
||||||
|
assert run_context["user_id"] == "U-platform"
|
||||||
|
assert run_context["channel_user_id"] == "U-platform"
|
||||||
|
assert _owner_headers(msg) is None
|
||||||
|
|
||||||
|
def test_unsafe_user_id_is_normalized_but_raw_preserved(self, monkeypatch):
|
||||||
from deerflow.config.paths import make_safe_user_id
|
from deerflow.config.paths import make_safe_user_id
|
||||||
|
|
||||||
manager = self._manager()
|
manager = self._manager()
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
raw = "user@example.com"
|
raw = "user@example.com"
|
||||||
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
|
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
|
||||||
|
|
||||||
@@ -2375,8 +2430,9 @@ class TestResolveRunParamsUserId:
|
|||||||
assert run_context["channel_user_id"] == raw
|
assert run_context["channel_user_id"] == raw
|
||||||
|
|
||||||
@pytest.mark.parametrize("raw_user_id", ["", None])
|
@pytest.mark.parametrize("raw_user_id", ["", None])
|
||||||
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id):
|
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id, monkeypatch):
|
||||||
manager = self._manager()
|
manager = self._manager()
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")
|
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")
|
||||||
|
|
||||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
@@ -2386,11 +2442,13 @@ class TestResolveRunParamsUserId:
|
|||||||
|
|
||||||
|
|
||||||
class TestChannelManagerConnectionRouting:
|
class TestChannelManagerConnectionRouting:
|
||||||
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path):
|
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path, monkeypatch):
|
||||||
from app.channels.manager import ChannelManager
|
from app.channels.manager import ChannelManager
|
||||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||||
from deerflow.persistence.engine import close_engine
|
from deerflow.persistence.engine import close_engine
|
||||||
|
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
|
|
||||||
async def go():
|
async def go():
|
||||||
repo = await _make_channel_connection_repo(tmp_path)
|
repo = await _make_channel_connection_repo(tmp_path)
|
||||||
alice = await repo.upsert_connection(
|
alice = await repo.upsert_connection(
|
||||||
|
|||||||
Reference in New Issue
Block a user