mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Make channel threads visible to connection owners
This commit is contained in:
@@ -2388,6 +2388,7 @@ class TestResolveRunParamsUserId:
|
||||
class TestChannelManagerConnectionRouting:
|
||||
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path):
|
||||
from app.channels.manager import ChannelManager
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||
from deerflow.persistence.engine import close_engine
|
||||
|
||||
async def go():
|
||||
@@ -2453,6 +2454,16 @@ class TestChannelManagerConnectionRouting:
|
||||
assert second_context["user_id"] == "bob"
|
||||
assert second_context["channel_user_id"] == "U-bob"
|
||||
|
||||
first_create_headers = mock_client.threads.create.call_args_list[0].kwargs["headers"]
|
||||
second_create_headers = mock_client.threads.create.call_args_list[1].kwargs["headers"]
|
||||
assert first_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
|
||||
assert second_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
|
||||
|
||||
first_run_headers = mock_client.runs.wait.call_args_list[0].kwargs["headers"]
|
||||
second_run_headers = mock_client.runs.wait.call_args_list[1].kwargs["headers"]
|
||||
assert first_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
|
||||
assert second_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
|
||||
|
||||
try:
|
||||
_run(go())
|
||||
finally:
|
||||
|
||||
@@ -474,6 +474,83 @@ def test_inject_authenticated_user_context_skips_internal_role():
|
||||
assert config["context"]["user_id"] == "channel-user-7"
|
||||
|
||||
|
||||
def test_start_run_uses_internal_owner_header_for_persistence():
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
||||
from app.gateway.services import start_run
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
from deerflow.runtime import RunManager
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
async def _scenario():
|
||||
run_store = MemoryRunStore()
|
||||
thread_store = MemoryThreadMetaStore(InMemoryStore())
|
||||
await thread_store.create("channel-thread", user_id="default", metadata={"legacy": True})
|
||||
run_manager = RunManager(store=run_store)
|
||||
state = SimpleNamespace(
|
||||
stream_bridge=SimpleNamespace(),
|
||||
run_manager=run_manager,
|
||||
checkpointer=InMemorySaver(),
|
||||
store=InMemoryStore(),
|
||||
run_event_store=SimpleNamespace(),
|
||||
run_events_config=None,
|
||||
thread_store=thread_store,
|
||||
)
|
||||
request = SimpleNamespace(
|
||||
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
|
||||
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
|
||||
app=SimpleNamespace(state=state),
|
||||
)
|
||||
body = SimpleNamespace(
|
||||
assistant_id="lead_agent",
|
||||
input={"messages": [{"role": "human", "content": "hi"}]},
|
||||
metadata={},
|
||||
config=None,
|
||||
context=None,
|
||||
on_disconnect="cancel",
|
||||
multitask_strategy="reject",
|
||||
stream_mode=None,
|
||||
stream_subgraphs=False,
|
||||
interrupt_before=None,
|
||||
interrupt_after=None,
|
||||
)
|
||||
task_context: dict[str, str] = {}
|
||||
|
||||
async def fake_run_agent(*args, **kwargs):
|
||||
task_context["user_id"] = get_effective_user_id()
|
||||
|
||||
with (
|
||||
patch("app.gateway.services.resolve_agent_factory", return_value=object()),
|
||||
patch("app.gateway.services.run_agent", side_effect=fake_run_agent),
|
||||
):
|
||||
record = await start_run(body, "channel-thread", request)
|
||||
await record.task
|
||||
|
||||
owner_run = await run_store.get(record.run_id, user_id="owner-1")
|
||||
default_run = await run_store.get(record.run_id, user_id="default")
|
||||
owner_thread = await thread_store.get("channel-thread", user_id="owner-1")
|
||||
default_thread = await thread_store.get("channel-thread", user_id="default")
|
||||
return owner_run, default_run, owner_thread, default_thread, task_context
|
||||
|
||||
owner_run, default_run, owner_thread, default_thread, task_context = asyncio.run(_scenario())
|
||||
|
||||
assert owner_run is not None
|
||||
assert owner_run["user_id"] == "owner-1"
|
||||
assert default_run is None
|
||||
assert owner_thread is not None
|
||||
assert owner_thread["user_id"] == "owner-1"
|
||||
assert owner_thread["metadata"] == {"legacy": True}
|
||||
assert default_thread is None
|
||||
assert task_context["user_id"] == "owner-1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -33,3 +33,18 @@ def test_internal_auth_generates_process_local_fallback(monkeypatch):
|
||||
assert reloaded.is_valid_internal_auth_token(token) is True
|
||||
finally:
|
||||
importlib.reload(reloaded)
|
||||
|
||||
|
||||
def test_internal_auth_headers_can_carry_owner_user_id(monkeypatch):
|
||||
import app.gateway.internal_auth as internal_auth
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", "shared-token")
|
||||
reloaded = importlib.reload(internal_auth)
|
||||
try:
|
||||
headers = reloaded.create_internal_auth_headers(owner_user_id="owner-1")
|
||||
|
||||
assert headers[reloaded.INTERNAL_AUTH_HEADER_NAME] == "shared-token"
|
||||
assert headers[reloaded.INTERNAL_OWNER_USER_ID_HEADER_NAME] == "owner-1"
|
||||
finally:
|
||||
monkeypatch.delenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", raising=False)
|
||||
importlib.reload(reloaded)
|
||||
|
||||
@@ -137,6 +137,19 @@ class TestThreadMetaRepository:
|
||||
async def test_update_metadata_nonexistent_is_noop(self, repo):
|
||||
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_owner_with_bypass_moves_row(self, repo):
|
||||
await repo.create("t1", user_id="default", metadata={"source": "channel"})
|
||||
await repo.update_owner("t1", "owner-1", user_id=None)
|
||||
|
||||
owner_row = await repo.get("t1", user_id="owner-1")
|
||||
default_row = await repo.get("t1", user_id="default")
|
||||
|
||||
assert owner_row is not None
|
||||
assert owner_row["user_id"] == "owner-1"
|
||||
assert owner_row["metadata"] == {"source": "channel"}
|
||||
assert default_row is None
|
||||
|
||||
# --- search with metadata filter (SQL push-down) ---
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -218,6 +219,37 @@ def test_create_thread_returns_iso_timestamps() -> None:
|
||||
assert body["created_at"] == body["updated_at"]
|
||||
|
||||
|
||||
def test_internal_owner_header_assigns_thread_to_owner() -> None:
|
||||
import asyncio
|
||||
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
||||
|
||||
store = InMemoryStore()
|
||||
checkpointer = InMemorySaver()
|
||||
thread_store = MemoryThreadMetaStore(store)
|
||||
request = SimpleNamespace(
|
||||
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
|
||||
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
|
||||
app=SimpleNamespace(state=SimpleNamespace(checkpointer=checkpointer, thread_store=thread_store)),
|
||||
)
|
||||
|
||||
async def _scenario():
|
||||
response = await threads.create_thread(
|
||||
threads.ThreadCreateRequest(thread_id="channel-thread", metadata={}),
|
||||
request,
|
||||
)
|
||||
owner_row = await thread_store.get("channel-thread", user_id="owner-1")
|
||||
internal_row = await thread_store.get("channel-thread", user_id="default")
|
||||
return response, owner_row, internal_row
|
||||
|
||||
response, owner_row, internal_row = asyncio.run(_scenario())
|
||||
|
||||
assert response.thread_id == "channel-thread"
|
||||
assert owner_row is not None
|
||||
assert owner_row["user_id"] == "owner-1"
|
||||
assert internal_row is None
|
||||
|
||||
|
||||
def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
|
||||
"""A thread record written by older versions stores ``time.time()``
|
||||
floats. ``get_thread`` must transparently surface them as ISO so the
|
||||
|
||||
Reference in New Issue
Block a user