fix(subagents): make subagent timeout terminal state atomic (#2583)
* Guard subagent terminal state transitions * fix: publish subagent terminal status last * Fix subagent timeout test to avoid blocking event loop * Fix subagent timeout test tracking * Refine subagent terminal state handling --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -47,6 +47,15 @@ class SubagentStatus(Enum):
|
||||
CANCELLED = "cancelled"
|
||||
TIMED_OUT = "timed_out"
|
||||
|
||||
@property
|
||||
def is_terminal(self) -> bool:
|
||||
return self in {
|
||||
type(self).COMPLETED,
|
||||
type(self).FAILED,
|
||||
type(self).CANCELLED,
|
||||
type(self).TIMED_OUT,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubagentResult:
|
||||
@@ -74,12 +83,48 @@ class SubagentResult:
|
||||
token_usage_records: list[dict[str, int | str]] = field(default_factory=list)
|
||||
usage_reported: bool = False
|
||||
cancel_event: threading.Event = field(default_factory=threading.Event, repr=False)
|
||||
_state_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize mutable defaults."""
|
||||
if self.ai_messages is None:
|
||||
self.ai_messages = []
|
||||
|
||||
def try_set_terminal(
|
||||
self,
|
||||
status: SubagentStatus,
|
||||
*,
|
||||
result: str | None = None,
|
||||
error: str | None = None,
|
||||
completed_at: datetime | None = None,
|
||||
ai_messages: list[dict[str, Any]] | None = None,
|
||||
token_usage_records: list[dict[str, int | str]] | None = None,
|
||||
) -> bool:
|
||||
"""Set a terminal status exactly once.
|
||||
|
||||
Background timeout/cancellation and the execution worker can race on the
|
||||
same result holder. The first terminal transition wins; late terminal
|
||||
writes must not change status or payload fields.
|
||||
"""
|
||||
if not status.is_terminal:
|
||||
raise ValueError(f"Status {status} is not terminal")
|
||||
|
||||
with self._state_lock:
|
||||
if self.status.is_terminal:
|
||||
return False
|
||||
|
||||
if result is not None:
|
||||
self.result = result
|
||||
if error is not None:
|
||||
self.error = error
|
||||
if ai_messages is not None:
|
||||
self.ai_messages = ai_messages
|
||||
if token_usage_records is not None:
|
||||
self.token_usage_records = token_usage_records
|
||||
self.completed_at = completed_at or datetime.now()
|
||||
self.status = status
|
||||
return True
|
||||
|
||||
|
||||
# Global storage for background task results
|
||||
_background_tasks: dict[str, SubagentResult] = {}
|
||||
@@ -459,13 +504,11 @@ class SubagentExecutor:
|
||||
# Pre-check: bail out immediately if already cancelled before streaming starts
|
||||
if result.cancel_event.is_set():
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled before streaming")
|
||||
with _background_tasks_lock:
|
||||
if result.status == SubagentStatus.RUNNING:
|
||||
result.status = SubagentStatus.CANCELLED
|
||||
result.error = "Cancelled by user"
|
||||
result.completed_at = datetime.now()
|
||||
if collector is not None:
|
||||
result.token_usage_records = collector.snapshot_records()
|
||||
result.try_set_terminal(
|
||||
SubagentStatus.CANCELLED,
|
||||
error="Cancelled by user",
|
||||
token_usage_records=collector.snapshot_records(),
|
||||
)
|
||||
return result
|
||||
|
||||
async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
|
||||
@@ -475,12 +518,11 @@ class SubagentExecutor:
|
||||
# interrupted until the next chunk is yielded.
|
||||
if result.cancel_event.is_set():
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled by parent")
|
||||
with _background_tasks_lock:
|
||||
if result.status == SubagentStatus.RUNNING:
|
||||
result.status = SubagentStatus.CANCELLED
|
||||
result.error = "Cancelled by user"
|
||||
result.completed_at = datetime.now()
|
||||
result.token_usage_records = collector.snapshot_records()
|
||||
result.try_set_terminal(
|
||||
SubagentStatus.CANCELLED,
|
||||
error="Cancelled by user",
|
||||
token_usage_records=collector.snapshot_records(),
|
||||
)
|
||||
return result
|
||||
|
||||
final_state = chunk
|
||||
@@ -507,11 +549,12 @@ class SubagentExecutor:
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}")
|
||||
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution")
|
||||
result.token_usage_records = collector.snapshot_records()
|
||||
token_usage_records = collector.snapshot_records()
|
||||
final_result: str | None = None
|
||||
|
||||
if final_state is None:
|
||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
|
||||
result.result = "No response generated"
|
||||
final_result = "No response generated"
|
||||
else:
|
||||
# Extract the final message - find the last AIMessage
|
||||
messages = final_state.get("messages", [])
|
||||
@@ -528,7 +571,7 @@ class SubagentExecutor:
|
||||
content = last_ai_message.content
|
||||
# Handle both str and list content types for the final result
|
||||
if isinstance(content, str):
|
||||
result.result = content
|
||||
final_result = content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from list of content blocks for final result only.
|
||||
# Concatenate raw string chunks directly, but preserve separation
|
||||
@@ -547,16 +590,16 @@ class SubagentExecutor:
|
||||
text_parts.append(text_val)
|
||||
if pending_str_parts:
|
||||
text_parts.append("".join(pending_str_parts))
|
||||
result.result = "\n".join(text_parts) if text_parts else "No text content in response"
|
||||
final_result = "\n".join(text_parts) if text_parts else "No text content in response"
|
||||
else:
|
||||
result.result = str(content)
|
||||
final_result = str(content)
|
||||
elif messages:
|
||||
# Fallback: use the last message if no AIMessage found
|
||||
last_message = messages[-1]
|
||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}")
|
||||
raw_content = last_message.content if hasattr(last_message, "content") else str(last_message)
|
||||
if isinstance(raw_content, str):
|
||||
result.result = raw_content
|
||||
final_result = raw_content
|
||||
elif isinstance(raw_content, list):
|
||||
parts = []
|
||||
pending_str_parts = []
|
||||
@@ -572,23 +615,29 @@ class SubagentExecutor:
|
||||
parts.append(text_val)
|
||||
if pending_str_parts:
|
||||
parts.append("".join(pending_str_parts))
|
||||
result.result = "\n".join(parts) if parts else "No text content in response"
|
||||
final_result = "\n".join(parts) if parts else "No text content in response"
|
||||
else:
|
||||
result.result = str(raw_content)
|
||||
final_result = str(raw_content)
|
||||
else:
|
||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state")
|
||||
result.result = "No response generated"
|
||||
final_result = "No response generated"
|
||||
|
||||
result.status = SubagentStatus.COMPLETED
|
||||
result.completed_at = datetime.now()
|
||||
if final_result is None:
|
||||
final_result = "No response generated"
|
||||
|
||||
result.try_set_terminal(
|
||||
SubagentStatus.COMPLETED,
|
||||
result=final_result,
|
||||
token_usage_records=token_usage_records,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed")
|
||||
result.status = SubagentStatus.FAILED
|
||||
result.error = str(e)
|
||||
result.completed_at = datetime.now()
|
||||
if collector is not None:
|
||||
result.token_usage_records = collector.snapshot_records()
|
||||
result.try_set_terminal(
|
||||
SubagentStatus.FAILED,
|
||||
error=str(e),
|
||||
token_usage_records=collector.snapshot_records() if collector is not None else None,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -667,11 +716,9 @@ class SubagentExecutor:
|
||||
result = SubagentResult(
|
||||
task_id=str(uuid.uuid4())[:8],
|
||||
trace_id=self.trace_id,
|
||||
status=SubagentStatus.FAILED,
|
||||
status=SubagentStatus.RUNNING,
|
||||
)
|
||||
result.status = SubagentStatus.FAILED
|
||||
result.error = str(e)
|
||||
result.completed_at = datetime.now()
|
||||
result.try_set_terminal(SubagentStatus.FAILED, error=str(e))
|
||||
return result
|
||||
|
||||
def execute_async(self, task: str, task_id: str | None = None) -> str:
|
||||
@@ -718,29 +765,21 @@ class SubagentExecutor:
|
||||
)
|
||||
try:
|
||||
# Wait for execution with timeout
|
||||
exec_result = execution_future.result(timeout=self.config.timeout_seconds)
|
||||
with _background_tasks_lock:
|
||||
_background_tasks[task_id].status = exec_result.status
|
||||
_background_tasks[task_id].result = exec_result.result
|
||||
_background_tasks[task_id].error = exec_result.error
|
||||
_background_tasks[task_id].completed_at = datetime.now()
|
||||
_background_tasks[task_id].ai_messages = exec_result.ai_messages
|
||||
execution_future.result(timeout=self.config.timeout_seconds)
|
||||
except FuturesTimeoutError:
|
||||
logger.error(f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s")
|
||||
with _background_tasks_lock:
|
||||
if _background_tasks[task_id].status == SubagentStatus.RUNNING:
|
||||
_background_tasks[task_id].status = SubagentStatus.TIMED_OUT
|
||||
_background_tasks[task_id].error = f"Execution timed out after {self.config.timeout_seconds} seconds"
|
||||
_background_tasks[task_id].completed_at = datetime.now()
|
||||
# Signal cooperative cancellation and cancel the future
|
||||
result_holder.cancel_event.set()
|
||||
result_holder.try_set_terminal(
|
||||
SubagentStatus.TIMED_OUT,
|
||||
error=f"Execution timed out after {self.config.timeout_seconds} seconds",
|
||||
)
|
||||
execution_future.cancel()
|
||||
except Exception as e:
|
||||
logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed")
|
||||
with _background_tasks_lock:
|
||||
_background_tasks[task_id].status = SubagentStatus.FAILED
|
||||
_background_tasks[task_id].error = str(e)
|
||||
_background_tasks[task_id].completed_at = datetime.now()
|
||||
task_result = _background_tasks[task_id]
|
||||
task_result.try_set_terminal(SubagentStatus.FAILED, error=str(e))
|
||||
|
||||
_scheduler_pool.submit(run_task)
|
||||
return task_id
|
||||
@@ -811,13 +850,7 @@ def cleanup_background_task(task_id: str) -> None:
|
||||
|
||||
# Only clean up tasks that are in a terminal state to avoid races with
|
||||
# the background executor still updating the task entry.
|
||||
is_terminal_status = result.status in {
|
||||
SubagentStatus.COMPLETED,
|
||||
SubagentStatus.FAILED,
|
||||
SubagentStatus.CANCELLED,
|
||||
SubagentStatus.TIMED_OUT,
|
||||
}
|
||||
if is_terminal_status or result.completed_at is not None:
|
||||
if result.status.is_terminal or result.completed_at is not None:
|
||||
del _background_tasks[task_id]
|
||||
logger.debug("Cleaned up background task: %s", task_id)
|
||||
else:
|
||||
|
||||
@@ -1125,6 +1125,15 @@ class TestAsyncToolSupport:
|
||||
class TestThreadSafety:
|
||||
"""Test thread safety of executor operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def executor_module(self, _setup_executor_classes):
|
||||
"""Import the executor module with real classes."""
|
||||
import importlib
|
||||
|
||||
from deerflow.subagents import executor
|
||||
|
||||
return importlib.reload(executor)
|
||||
|
||||
def test_multiple_executors_in_parallel(self, classes, base_config, msg):
|
||||
"""Test multiple executors running in parallel via thread pool."""
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
@@ -1170,6 +1179,68 @@ class TestThreadSafety:
|
||||
assert result.status == SubagentStatus.COMPLETED
|
||||
assert "Result" in result.result
|
||||
|
||||
def test_terminal_status_is_published_after_payload_fields(self, executor_module, monkeypatch):
|
||||
"""Readers must not observe terminal status before terminal payload is complete."""
|
||||
SubagentResult = executor_module.SubagentResult
|
||||
SubagentStatus = executor_module.SubagentStatus
|
||||
|
||||
now_entered = threading.Event()
|
||||
release_now = threading.Event()
|
||||
completed_at = datetime(2026, 5, 1, 12, 0, 0)
|
||||
writer_errors: list[BaseException] = []
|
||||
|
||||
class BlockingDateTime:
|
||||
@staticmethod
|
||||
def now():
|
||||
now_entered.set()
|
||||
release_now.wait(timeout=5)
|
||||
return completed_at
|
||||
|
||||
monkeypatch.setattr(executor_module, "datetime", BlockingDateTime)
|
||||
|
||||
result = SubagentResult(
|
||||
task_id="test-terminal-publication-order",
|
||||
trace_id="test-trace",
|
||||
status=SubagentStatus.RUNNING,
|
||||
)
|
||||
token_usage_records = [
|
||||
{
|
||||
"source_run_id": "run-1",
|
||||
"caller": "subagent:test-agent",
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
}
|
||||
]
|
||||
|
||||
def set_terminal():
|
||||
try:
|
||||
assert result.try_set_terminal(
|
||||
SubagentStatus.COMPLETED,
|
||||
result="done",
|
||||
token_usage_records=token_usage_records,
|
||||
)
|
||||
except BaseException as exc:
|
||||
writer_errors.append(exc)
|
||||
|
||||
writer = threading.Thread(target=set_terminal)
|
||||
writer.start()
|
||||
|
||||
assert now_entered.wait(timeout=3), "try_set_terminal did not reach completed_at assignment"
|
||||
assert result.completed_at is None
|
||||
assert result.status == SubagentStatus.RUNNING
|
||||
assert result.token_usage_records == token_usage_records
|
||||
|
||||
release_now.set()
|
||||
writer.join(timeout=3)
|
||||
|
||||
assert not writer.is_alive(), "try_set_terminal did not finish"
|
||||
assert writer_errors == []
|
||||
assert result.completed_at == completed_at
|
||||
assert result.status == SubagentStatus.COMPLETED
|
||||
assert result.result == "done"
|
||||
assert result.token_usage_records == token_usage_records
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Cleanup Background Task Tests
|
||||
@@ -1604,6 +1675,69 @@ class TestCooperativeCancellation:
|
||||
assert result.error == "Cancelled by user"
|
||||
assert result.completed_at is not None
|
||||
|
||||
def test_late_completion_after_timeout_does_not_overwrite_timed_out(self, executor_module, classes, msg):
|
||||
"""Late completion from the execution worker must not overwrite TIMED_OUT."""
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
SubagentStatus = classes["SubagentStatus"]
|
||||
|
||||
short_config = classes["SubagentConfig"](
|
||||
name="test-agent",
|
||||
description="Test agent",
|
||||
system_prompt="You are a test agent.",
|
||||
max_turns=10,
|
||||
timeout_seconds=0.05,
|
||||
)
|
||||
|
||||
first_chunk_seen = threading.Event()
|
||||
finish_stream = threading.Event()
|
||||
execution_done = threading.Event()
|
||||
|
||||
async def mock_astream(*args, **kwargs):
|
||||
yield {"messages": [msg.human("Task"), msg.ai("late completion", "msg-late")]}
|
||||
first_chunk_seen.set()
|
||||
deadline = asyncio.get_running_loop().time() + 5
|
||||
while not finish_stream.is_set():
|
||||
if asyncio.get_running_loop().time() >= deadline:
|
||||
break
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.astream = mock_astream
|
||||
|
||||
executor = SubagentExecutor(
|
||||
config=short_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
trace_id="test-trace",
|
||||
)
|
||||
original_aexecute = executor._aexecute
|
||||
|
||||
async def tracked_aexecute(task, result_holder=None):
|
||||
try:
|
||||
return await original_aexecute(task, result_holder)
|
||||
finally:
|
||||
execution_done.set()
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent), patch.object(executor, "_aexecute", tracked_aexecute):
|
||||
task_id = executor.execute_async("Task")
|
||||
assert first_chunk_seen.wait(timeout=3), "stream did not yield initial chunk"
|
||||
|
||||
result = executor_module._background_tasks[task_id]
|
||||
assert result.cancel_event.wait(timeout=3), "timeout handler did not request cancellation"
|
||||
assert result.status.value == SubagentStatus.TIMED_OUT.value
|
||||
timed_out_error = result.error
|
||||
timed_out_completed_at = result.completed_at
|
||||
|
||||
finish_stream.set()
|
||||
assert execution_done.wait(timeout=3), "execution worker did not finish"
|
||||
|
||||
result = executor_module._background_tasks.get(task_id)
|
||||
assert result is not None
|
||||
assert result.status.value == SubagentStatus.TIMED_OUT.value
|
||||
assert result.result is None
|
||||
assert result.error == timed_out_error
|
||||
assert result.completed_at == timed_out_completed_at
|
||||
|
||||
def test_cleanup_removes_cancelled_task(self, executor_module, classes):
|
||||
"""Test that cleanup removes a CANCELLED task (terminal state)."""
|
||||
SubagentResult = classes["SubagentResult"]
|
||||
|
||||
Reference in New Issue
Block a user