mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 18:05:58 +00:00
Make channel threads visible to connection owners
This commit is contained in:
@@ -71,6 +71,15 @@ class ThreadMetaStore(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
"""Move a thread metadata row to a new owner.
|
||||
|
||||
Intended for trusted internal repair/migration paths. No-op if the
|
||||
row does not exist or the caller fails the owner check.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``user_id`` has access to ``thread_id``."""
|
||||
|
||||
@@ -127,6 +127,14 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
record["updated_at"] = now_iso()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_owner")
|
||||
if record is None:
|
||||
return
|
||||
record["user_id"] = owner_user_id
|
||||
record["updated_at"] = now_iso()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
|
||||
if record is None:
|
||||
|
||||
@@ -211,6 +211,21 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
row.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
async def update_owner(
|
||||
self,
|
||||
thread_id: str,
|
||||
owner_user_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Move a thread metadata row to ``owner_user_id``."""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_owner")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(user_id=owner_user_id, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
thread_id: str,
|
||||
|
||||
@@ -83,6 +83,7 @@ class RunRecord:
|
||||
multitask_strategy: str = "reject"
|
||||
metadata: dict = field(default_factory=dict)
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
user_id: str | None = None
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
task: asyncio.Task | None = field(default=None, repr=False)
|
||||
@@ -124,7 +125,7 @@ class RunManager:
|
||||
|
||||
@staticmethod
|
||||
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
|
||||
return {
|
||||
payload = {
|
||||
"thread_id": record.thread_id,
|
||||
"assistant_id": record.assistant_id,
|
||||
"status": record.status.value,
|
||||
@@ -135,6 +136,9 @@ class RunManager:
|
||||
"created_at": record.created_at,
|
||||
"model_name": record.model_name,
|
||||
}
|
||||
if record.user_id is not None:
|
||||
payload["user_id"] = record.user_id
|
||||
return payload
|
||||
|
||||
async def _call_store_with_retry(
|
||||
self,
|
||||
@@ -241,6 +245,7 @@ class RunManager:
|
||||
kwargs=row.get("kwargs") or {},
|
||||
created_at=row.get("created_at") or "",
|
||||
updated_at=row.get("updated_at") or "",
|
||||
user_id=row.get("user_id"),
|
||||
error=row.get("error"),
|
||||
model_name=row.get("model_name"),
|
||||
store_only=True,
|
||||
@@ -320,6 +325,7 @@ class RunManager:
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
user_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Create a new pending run and register it."""
|
||||
run_id = str(uuid.uuid4())
|
||||
@@ -333,6 +339,7 @@ class RunManager:
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
user_id=user_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
@@ -504,6 +511,7 @@ class RunManager:
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
model_name: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Atomically check for inflight runs and create a new one.
|
||||
|
||||
@@ -546,6 +554,7 @@ class RunManager:
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
user_id=user_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
model_name=model_name,
|
||||
|
||||
Reference in New Issue
Block a user