fix(subagents): propagate user context across threaded execution (#2676)

This commit is contained in:
JerryLee
2026-05-01 18:27:18 +10:00
committed by GitHub
parent 78633c69ac
commit 83938cf35a
2 changed files with 88 additions and 10 deletions
@@ -5,8 +5,10 @@ import atexit
import logging import logging
import threading import threading
import uuid import uuid
from collections.abc import Callable, Coroutine
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError from concurrent.futures import TimeoutError as FuturesTimeoutError
from contextvars import Context, copy_context
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@@ -168,6 +170,19 @@ def _get_isolated_subagent_loop() -> asyncio.AbstractEventLoop:
return _isolated_subagent_loop return _isolated_subagent_loop
def _submit_to_isolated_loop_in_context(
context: Context,
coro_factory: Callable[[], Coroutine[Any, Any, SubagentResult]],
) -> Future[SubagentResult]:
"""Submit a coroutine to the isolated loop while preserving ContextVar state."""
return context.run(
lambda: asyncio.run_coroutine_threadsafe(
coro_factory(),
_get_isolated_subagent_loop(),
)
)
def _filter_tools( def _filter_tools(
all_tools: list[BaseTool], all_tools: list[BaseTool],
allowed: list[str] | None, allowed: list[str] | None,
@@ -549,10 +564,11 @@ class SubagentExecutor:
from being tied to a short-lived loop that gets closed per execution. from being tied to a short-lived loop that gets closed per execution.
""" """
future: Future[SubagentResult] | None = None future: Future[SubagentResult] | None = None
parent_context = copy_context()
try: try:
future = asyncio.run_coroutine_threadsafe( future = _submit_to_isolated_loop_in_context(
self._aexecute(task, result_holder), parent_context,
_get_isolated_subagent_loop(), lambda: self._aexecute(task, result_holder),
) )
return future.result(timeout=self.config.timeout_seconds) return future.result(timeout=self.config.timeout_seconds)
except FuturesTimeoutError: except FuturesTimeoutError:
@@ -646,6 +662,8 @@ class SubagentExecutor:
with _background_tasks_lock: with _background_tasks_lock:
_background_tasks[task_id] = result _background_tasks[task_id] = result
parent_context = copy_context()
# Submit to scheduler pool # Submit to scheduler pool
def run_task(): def run_task():
with _background_tasks_lock: with _background_tasks_lock:
@@ -656,9 +674,9 @@ class SubagentExecutor:
try: try:
# Submit execution directly to the persistent isolated loop so the # Submit execution directly to the persistent isolated loop so the
# background path does not create a temporary loop via execute(). # background path does not create a temporary loop via execute().
execution_future = asyncio.run_coroutine_threadsafe( execution_future = _submit_to_isolated_loop_in_context(
self._aexecute(task, result_holder), parent_context,
_get_isolated_subagent_loop(), lambda: self._aexecute(task, result_holder),
) )
try: try:
# Wait for execution with timeout # Wait for execution with timeout
+64 -4
View File
@@ -17,7 +17,7 @@ import asyncio
import sys import sys
import threading import threading
from datetime import datetime from datetime import datetime
from types import ModuleType from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@@ -526,12 +526,19 @@ class TestSyncExecutionPath:
@pytest.mark.anyio @pytest.mark.anyio
async def test_execute_in_running_event_loop_calls_isolated_loop_directly(self, classes, base_config, mock_agent, msg): async def test_execute_in_running_event_loop_calls_isolated_loop_directly(self, classes, base_config, mock_agent, msg):
"""Test that execute() calls the isolated-loop helper directly in a running loop.""" """Test that execute() calls the isolated-loop helper directly in a running loop."""
from deerflow.runtime.user_context import (
get_effective_user_id,
reset_current_user,
set_current_user,
)
SubagentExecutor = classes["SubagentExecutor"] SubagentExecutor = classes["SubagentExecutor"]
SubagentStatus = classes["SubagentStatus"] SubagentStatus = classes["SubagentStatus"]
caller_thread = threading.current_thread().name caller_thread = threading.current_thread().name
isolated_helper_threads = [] isolated_helper_threads = []
execution_threads = [] execution_threads = []
effective_user_ids = []
final_state = { final_state = {
"messages": [ "messages": [
msg.human("Task"), msg.human("Task"),
@@ -541,6 +548,7 @@ class TestSyncExecutionPath:
async def mock_astream(*args, **kwargs): async def mock_astream(*args, **kwargs):
execution_threads.append(threading.current_thread().name) execution_threads.append(threading.current_thread().name)
effective_user_ids.append(get_effective_user_id())
yield final_state yield final_state
mock_agent.astream = mock_astream mock_agent.astream = mock_astream
@@ -557,14 +565,19 @@ class TestSyncExecutionPath:
isolated_helper_threads.append(threading.current_thread().name) isolated_helper_threads.append(threading.current_thread().name)
return original_isolated_execute(task, result_holder) return original_isolated_execute(task, result_holder)
with patch.object(executor, "_create_agent", return_value=mock_agent): token = set_current_user(SimpleNamespace(id="alice"))
with patch.object(executor, "_execute_in_isolated_loop", side_effect=tracked_isolated_execute) as isolated: try:
result = executor.execute("Task") with patch.object(executor, "_create_agent", return_value=mock_agent):
with patch.object(executor, "_execute_in_isolated_loop", side_effect=tracked_isolated_execute) as isolated:
result = executor.execute("Task")
finally:
reset_current_user(token)
assert isolated.call_count == 1 assert isolated.call_count == 1
assert isolated_helper_threads == [caller_thread] assert isolated_helper_threads == [caller_thread]
assert execution_threads assert execution_threads
assert execution_threads == ["subagent-persistent-loop"] assert execution_threads == ["subagent-persistent-loop"]
assert effective_user_ids == ["alice"]
assert result.status == SubagentStatus.COMPLETED assert result.status == SubagentStatus.COMPLETED
assert result.result == "Async loop result" assert result.result == "Async loop result"
@@ -1114,6 +1127,53 @@ class TestCooperativeCancellation:
assert result.result == "done: Task" assert result.result == "done: Task"
assert result.error is None assert result.error is None
def test_execute_async_propagates_user_context_to_isolated_loop(self, executor_module, classes, base_config):
"""Regression: background subagent execution must keep request user context."""
import concurrent.futures
from deerflow.runtime.user_context import (
get_effective_user_id,
reset_current_user,
set_current_user,
)
SubagentExecutor = classes["SubagentExecutor"]
SubagentStatus = classes["SubagentStatus"]
async def fake_aexecute(task, result_holder=None):
result = result_holder
result.status = SubagentStatus.COMPLETED
result.result = get_effective_user_id()
result.completed_at = datetime.now()
return result
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
trace_id="test-trace",
)
scheduler = concurrent.futures.ThreadPoolExecutor(max_workers=1)
token = set_current_user(SimpleNamespace(id="alice"))
try:
with (
patch.object(executor_module, "_scheduler_pool", scheduler),
patch.object(executor, "_aexecute", side_effect=fake_aexecute),
patch.object(executor, "execute", side_effect=AssertionError("execute() should not be called by execute_async")),
):
task_id = executor.execute_async("Task")
executor_module._scheduler_pool.shutdown(wait=True)
finally:
reset_current_user(token)
scheduler.shutdown(wait=False, cancel_futures=True)
result = executor_module._background_tasks.get(task_id)
assert result is not None
assert result.status == SubagentStatus.COMPLETED
assert result.result == "alice"
assert result.error is None
def test_timeout_does_not_overwrite_cancelled(self, executor_module, classes, base_config, msg): def test_timeout_does_not_overwrite_cancelled(self, executor_module, classes, base_config, msg):
"""Test that the real timeout handler does not overwrite CANCELLED status. """Test that the real timeout handler does not overwrite CANCELLED status.