fix(subagents): propagate user context across threaded execution (#2676)
This commit is contained in:
@@ -5,8 +5,10 @@ import atexit
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
||||||
|
from contextvars import Context, copy_context
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -168,6 +170,19 @@ def _get_isolated_subagent_loop() -> asyncio.AbstractEventLoop:
|
|||||||
return _isolated_subagent_loop
|
return _isolated_subagent_loop
|
||||||
|
|
||||||
|
|
||||||
|
def _submit_to_isolated_loop_in_context(
|
||||||
|
context: Context,
|
||||||
|
coro_factory: Callable[[], Coroutine[Any, Any, SubagentResult]],
|
||||||
|
) -> Future[SubagentResult]:
|
||||||
|
"""Submit a coroutine to the isolated loop while preserving ContextVar state."""
|
||||||
|
return context.run(
|
||||||
|
lambda: asyncio.run_coroutine_threadsafe(
|
||||||
|
coro_factory(),
|
||||||
|
_get_isolated_subagent_loop(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _filter_tools(
|
def _filter_tools(
|
||||||
all_tools: list[BaseTool],
|
all_tools: list[BaseTool],
|
||||||
allowed: list[str] | None,
|
allowed: list[str] | None,
|
||||||
@@ -549,10 +564,11 @@ class SubagentExecutor:
|
|||||||
from being tied to a short-lived loop that gets closed per execution.
|
from being tied to a short-lived loop that gets closed per execution.
|
||||||
"""
|
"""
|
||||||
future: Future[SubagentResult] | None = None
|
future: Future[SubagentResult] | None = None
|
||||||
|
parent_context = copy_context()
|
||||||
try:
|
try:
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
future = _submit_to_isolated_loop_in_context(
|
||||||
self._aexecute(task, result_holder),
|
parent_context,
|
||||||
_get_isolated_subagent_loop(),
|
lambda: self._aexecute(task, result_holder),
|
||||||
)
|
)
|
||||||
return future.result(timeout=self.config.timeout_seconds)
|
return future.result(timeout=self.config.timeout_seconds)
|
||||||
except FuturesTimeoutError:
|
except FuturesTimeoutError:
|
||||||
@@ -646,6 +662,8 @@ class SubagentExecutor:
|
|||||||
with _background_tasks_lock:
|
with _background_tasks_lock:
|
||||||
_background_tasks[task_id] = result
|
_background_tasks[task_id] = result
|
||||||
|
|
||||||
|
parent_context = copy_context()
|
||||||
|
|
||||||
# Submit to scheduler pool
|
# Submit to scheduler pool
|
||||||
def run_task():
|
def run_task():
|
||||||
with _background_tasks_lock:
|
with _background_tasks_lock:
|
||||||
@@ -656,9 +674,9 @@ class SubagentExecutor:
|
|||||||
try:
|
try:
|
||||||
# Submit execution directly to the persistent isolated loop so the
|
# Submit execution directly to the persistent isolated loop so the
|
||||||
# background path does not create a temporary loop via execute().
|
# background path does not create a temporary loop via execute().
|
||||||
execution_future = asyncio.run_coroutine_threadsafe(
|
execution_future = _submit_to_isolated_loop_in_context(
|
||||||
self._aexecute(task, result_holder),
|
parent_context,
|
||||||
_get_isolated_subagent_loop(),
|
lambda: self._aexecute(task, result_holder),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# Wait for execution with timeout
|
# Wait for execution with timeout
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import asyncio
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from types import ModuleType
|
from types import ModuleType, SimpleNamespace
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -526,12 +526,19 @@ class TestSyncExecutionPath:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_execute_in_running_event_loop_calls_isolated_loop_directly(self, classes, base_config, mock_agent, msg):
|
async def test_execute_in_running_event_loop_calls_isolated_loop_directly(self, classes, base_config, mock_agent, msg):
|
||||||
"""Test that execute() calls the isolated-loop helper directly in a running loop."""
|
"""Test that execute() calls the isolated-loop helper directly in a running loop."""
|
||||||
|
from deerflow.runtime.user_context import (
|
||||||
|
get_effective_user_id,
|
||||||
|
reset_current_user,
|
||||||
|
set_current_user,
|
||||||
|
)
|
||||||
|
|
||||||
SubagentExecutor = classes["SubagentExecutor"]
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
SubagentStatus = classes["SubagentStatus"]
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
caller_thread = threading.current_thread().name
|
caller_thread = threading.current_thread().name
|
||||||
isolated_helper_threads = []
|
isolated_helper_threads = []
|
||||||
execution_threads = []
|
execution_threads = []
|
||||||
|
effective_user_ids = []
|
||||||
final_state = {
|
final_state = {
|
||||||
"messages": [
|
"messages": [
|
||||||
msg.human("Task"),
|
msg.human("Task"),
|
||||||
@@ -541,6 +548,7 @@ class TestSyncExecutionPath:
|
|||||||
|
|
||||||
async def mock_astream(*args, **kwargs):
|
async def mock_astream(*args, **kwargs):
|
||||||
execution_threads.append(threading.current_thread().name)
|
execution_threads.append(threading.current_thread().name)
|
||||||
|
effective_user_ids.append(get_effective_user_id())
|
||||||
yield final_state
|
yield final_state
|
||||||
|
|
||||||
mock_agent.astream = mock_astream
|
mock_agent.astream = mock_astream
|
||||||
@@ -557,14 +565,19 @@ class TestSyncExecutionPath:
|
|||||||
isolated_helper_threads.append(threading.current_thread().name)
|
isolated_helper_threads.append(threading.current_thread().name)
|
||||||
return original_isolated_execute(task, result_holder)
|
return original_isolated_execute(task, result_holder)
|
||||||
|
|
||||||
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
token = set_current_user(SimpleNamespace(id="alice"))
|
||||||
with patch.object(executor, "_execute_in_isolated_loop", side_effect=tracked_isolated_execute) as isolated:
|
try:
|
||||||
result = executor.execute("Task")
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
with patch.object(executor, "_execute_in_isolated_loop", side_effect=tracked_isolated_execute) as isolated:
|
||||||
|
result = executor.execute("Task")
|
||||||
|
finally:
|
||||||
|
reset_current_user(token)
|
||||||
|
|
||||||
assert isolated.call_count == 1
|
assert isolated.call_count == 1
|
||||||
assert isolated_helper_threads == [caller_thread]
|
assert isolated_helper_threads == [caller_thread]
|
||||||
assert execution_threads
|
assert execution_threads
|
||||||
assert execution_threads == ["subagent-persistent-loop"]
|
assert execution_threads == ["subagent-persistent-loop"]
|
||||||
|
assert effective_user_ids == ["alice"]
|
||||||
assert result.status == SubagentStatus.COMPLETED
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
assert result.result == "Async loop result"
|
assert result.result == "Async loop result"
|
||||||
|
|
||||||
@@ -1114,6 +1127,53 @@ class TestCooperativeCancellation:
|
|||||||
assert result.result == "done: Task"
|
assert result.result == "done: Task"
|
||||||
assert result.error is None
|
assert result.error is None
|
||||||
|
|
||||||
|
def test_execute_async_propagates_user_context_to_isolated_loop(self, executor_module, classes, base_config):
|
||||||
|
"""Regression: background subagent execution must keep request user context."""
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
from deerflow.runtime.user_context import (
|
||||||
|
get_effective_user_id,
|
||||||
|
reset_current_user,
|
||||||
|
set_current_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
async def fake_aexecute(task, result_holder=None):
|
||||||
|
result = result_holder
|
||||||
|
result.status = SubagentStatus.COMPLETED
|
||||||
|
result.result = get_effective_user_id()
|
||||||
|
result.completed_at = datetime.now()
|
||||||
|
return result
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
trace_id="test-trace",
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||||
|
token = set_current_user(SimpleNamespace(id="alice"))
|
||||||
|
try:
|
||||||
|
with (
|
||||||
|
patch.object(executor_module, "_scheduler_pool", scheduler),
|
||||||
|
patch.object(executor, "_aexecute", side_effect=fake_aexecute),
|
||||||
|
patch.object(executor, "execute", side_effect=AssertionError("execute() should not be called by execute_async")),
|
||||||
|
):
|
||||||
|
task_id = executor.execute_async("Task")
|
||||||
|
executor_module._scheduler_pool.shutdown(wait=True)
|
||||||
|
finally:
|
||||||
|
reset_current_user(token)
|
||||||
|
scheduler.shutdown(wait=False, cancel_futures=True)
|
||||||
|
|
||||||
|
result = executor_module._background_tasks.get(task_id)
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
assert result.result == "alice"
|
||||||
|
assert result.error is None
|
||||||
|
|
||||||
def test_timeout_does_not_overwrite_cancelled(self, executor_module, classes, base_config, msg):
|
def test_timeout_does_not_overwrite_cancelled(self, executor_module, classes, base_config, msg):
|
||||||
"""Test that the real timeout handler does not overwrite CANCELLED status.
|
"""Test that the real timeout handler does not overwrite CANCELLED status.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user