fix(subagents): make subagent timeout terminal state atomic (#2583)

* Guard subagent terminal state transitions

* fix: publish subagent terminal status last

* Fix subagent timeout test to avoid blocking event loop

* Fix subagent timeout test tracking

* Refine subagent terminal state handling

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
KiteEater
2026-05-18 22:19:32 +08:00
committed by GitHub
parent b5108e3520
commit 3acca12614
2 changed files with 222 additions and 55 deletions
@@ -47,6 +47,15 @@ class SubagentStatus(Enum):
CANCELLED = "cancelled"
TIMED_OUT = "timed_out"
@property
def is_terminal(self) -> bool:
return self in {
type(self).COMPLETED,
type(self).FAILED,
type(self).CANCELLED,
type(self).TIMED_OUT,
}
@dataclass
class SubagentResult:
@@ -74,12 +83,48 @@ class SubagentResult:
token_usage_records: list[dict[str, int | str]] = field(default_factory=list)
usage_reported: bool = False
cancel_event: threading.Event = field(default_factory=threading.Event, repr=False)
_state_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
def __post_init__(self):
"""Initialize mutable defaults."""
if self.ai_messages is None:
self.ai_messages = []
def try_set_terminal(
self,
status: SubagentStatus,
*,
result: str | None = None,
error: str | None = None,
completed_at: datetime | None = None,
ai_messages: list[dict[str, Any]] | None = None,
token_usage_records: list[dict[str, int | str]] | None = None,
) -> bool:
"""Set a terminal status exactly once.
Background timeout/cancellation and the execution worker can race on the
same result holder. The first terminal transition wins; late terminal
writes must not change status or payload fields.
"""
if not status.is_terminal:
raise ValueError(f"Status {status} is not terminal")
with self._state_lock:
if self.status.is_terminal:
return False
if result is not None:
self.result = result
if error is not None:
self.error = error
if ai_messages is not None:
self.ai_messages = ai_messages
if token_usage_records is not None:
self.token_usage_records = token_usage_records
self.completed_at = completed_at or datetime.now()
self.status = status
return True
# Global storage for background task results
_background_tasks: dict[str, SubagentResult] = {}
@@ -459,13 +504,11 @@ class SubagentExecutor:
# Pre-check: bail out immediately if already cancelled before streaming starts
if result.cancel_event.is_set():
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled before streaming")
with _background_tasks_lock:
if result.status == SubagentStatus.RUNNING:
result.status = SubagentStatus.CANCELLED
result.error = "Cancelled by user"
result.completed_at = datetime.now()
if collector is not None:
result.token_usage_records = collector.snapshot_records()
result.try_set_terminal(
SubagentStatus.CANCELLED,
error="Cancelled by user",
token_usage_records=collector.snapshot_records(),
)
return result
async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
@@ -475,12 +518,11 @@ class SubagentExecutor:
# interrupted until the next chunk is yielded.
if result.cancel_event.is_set():
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled by parent")
with _background_tasks_lock:
if result.status == SubagentStatus.RUNNING:
result.status = SubagentStatus.CANCELLED
result.error = "Cancelled by user"
result.completed_at = datetime.now()
result.token_usage_records = collector.snapshot_records()
result.try_set_terminal(
SubagentStatus.CANCELLED,
error="Cancelled by user",
token_usage_records=collector.snapshot_records(),
)
return result
final_state = chunk
@@ -507,11 +549,12 @@ class SubagentExecutor:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}")
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution")
result.token_usage_records = collector.snapshot_records()
token_usage_records = collector.snapshot_records()
final_result: str | None = None
if final_state is None:
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
result.result = "No response generated"
final_result = "No response generated"
else:
# Extract the final message - find the last AIMessage
messages = final_state.get("messages", [])
@@ -528,7 +571,7 @@ class SubagentExecutor:
content = last_ai_message.content
# Handle both str and list content types for the final result
if isinstance(content, str):
result.result = content
final_result = content
elif isinstance(content, list):
# Extract text from list of content blocks for final result only.
# Concatenate raw string chunks directly, but preserve separation
@@ -547,16 +590,16 @@ class SubagentExecutor:
text_parts.append(text_val)
if pending_str_parts:
text_parts.append("".join(pending_str_parts))
result.result = "\n".join(text_parts) if text_parts else "No text content in response"
final_result = "\n".join(text_parts) if text_parts else "No text content in response"
else:
result.result = str(content)
final_result = str(content)
elif messages:
# Fallback: use the last message if no AIMessage found
last_message = messages[-1]
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}")
raw_content = last_message.content if hasattr(last_message, "content") else str(last_message)
if isinstance(raw_content, str):
result.result = raw_content
final_result = raw_content
elif isinstance(raw_content, list):
parts = []
pending_str_parts = []
@@ -572,23 +615,29 @@ class SubagentExecutor:
parts.append(text_val)
if pending_str_parts:
parts.append("".join(pending_str_parts))
result.result = "\n".join(parts) if parts else "No text content in response"
final_result = "\n".join(parts) if parts else "No text content in response"
else:
result.result = str(raw_content)
final_result = str(raw_content)
else:
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state")
result.result = "No response generated"
final_result = "No response generated"
result.status = SubagentStatus.COMPLETED
result.completed_at = datetime.now()
if final_result is None:
final_result = "No response generated"
result.try_set_terminal(
SubagentStatus.COMPLETED,
result=final_result,
token_usage_records=token_usage_records,
)
except Exception as e:
logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed")
result.status = SubagentStatus.FAILED
result.error = str(e)
result.completed_at = datetime.now()
if collector is not None:
result.token_usage_records = collector.snapshot_records()
result.try_set_terminal(
SubagentStatus.FAILED,
error=str(e),
token_usage_records=collector.snapshot_records() if collector is not None else None,
)
return result
@@ -667,11 +716,9 @@ class SubagentExecutor:
result = SubagentResult(
task_id=str(uuid.uuid4())[:8],
trace_id=self.trace_id,
status=SubagentStatus.FAILED,
status=SubagentStatus.RUNNING,
)
result.status = SubagentStatus.FAILED
result.error = str(e)
result.completed_at = datetime.now()
result.try_set_terminal(SubagentStatus.FAILED, error=str(e))
return result
def execute_async(self, task: str, task_id: str | None = None) -> str:
@@ -718,29 +765,21 @@ class SubagentExecutor:
)
try:
# Wait for execution with timeout
exec_result = execution_future.result(timeout=self.config.timeout_seconds)
with _background_tasks_lock:
_background_tasks[task_id].status = exec_result.status
_background_tasks[task_id].result = exec_result.result
_background_tasks[task_id].error = exec_result.error
_background_tasks[task_id].completed_at = datetime.now()
_background_tasks[task_id].ai_messages = exec_result.ai_messages
execution_future.result(timeout=self.config.timeout_seconds)
except FuturesTimeoutError:
logger.error(f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s")
with _background_tasks_lock:
if _background_tasks[task_id].status == SubagentStatus.RUNNING:
_background_tasks[task_id].status = SubagentStatus.TIMED_OUT
_background_tasks[task_id].error = f"Execution timed out after {self.config.timeout_seconds} seconds"
_background_tasks[task_id].completed_at = datetime.now()
# Signal cooperative cancellation and cancel the future
result_holder.cancel_event.set()
result_holder.try_set_terminal(
SubagentStatus.TIMED_OUT,
error=f"Execution timed out after {self.config.timeout_seconds} seconds",
)
execution_future.cancel()
except Exception as e:
logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed")
with _background_tasks_lock:
_background_tasks[task_id].status = SubagentStatus.FAILED
_background_tasks[task_id].error = str(e)
_background_tasks[task_id].completed_at = datetime.now()
task_result = _background_tasks[task_id]
task_result.try_set_terminal(SubagentStatus.FAILED, error=str(e))
_scheduler_pool.submit(run_task)
return task_id
@@ -811,13 +850,7 @@ def cleanup_background_task(task_id: str) -> None:
# Only clean up tasks that are in a terminal state to avoid races with
# the background executor still updating the task entry.
is_terminal_status = result.status in {
SubagentStatus.COMPLETED,
SubagentStatus.FAILED,
SubagentStatus.CANCELLED,
SubagentStatus.TIMED_OUT,
}
if is_terminal_status or result.completed_at is not None:
if result.status.is_terminal or result.completed_at is not None:
del _background_tasks[task_id]
logger.debug("Cleaned up background task: %s", task_id)
else:
+134
View File
@@ -1125,6 +1125,15 @@ class TestAsyncToolSupport:
class TestThreadSafety:
"""Test thread safety of executor operations."""
@pytest.fixture
def executor_module(self, _setup_executor_classes):
"""Import the executor module with real classes."""
import importlib
from deerflow.subagents import executor
return importlib.reload(executor)
def test_multiple_executors_in_parallel(self, classes, base_config, msg):
"""Test multiple executors running in parallel via thread pool."""
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -1170,6 +1179,68 @@ class TestThreadSafety:
assert result.status == SubagentStatus.COMPLETED
assert "Result" in result.result
def test_terminal_status_is_published_after_payload_fields(self, executor_module, monkeypatch):
"""Readers must not observe terminal status before terminal payload is complete."""
SubagentResult = executor_module.SubagentResult
SubagentStatus = executor_module.SubagentStatus
now_entered = threading.Event()
release_now = threading.Event()
completed_at = datetime(2026, 5, 1, 12, 0, 0)
writer_errors: list[BaseException] = []
class BlockingDateTime:
@staticmethod
def now():
now_entered.set()
release_now.wait(timeout=5)
return completed_at
monkeypatch.setattr(executor_module, "datetime", BlockingDateTime)
result = SubagentResult(
task_id="test-terminal-publication-order",
trace_id="test-trace",
status=SubagentStatus.RUNNING,
)
token_usage_records = [
{
"source_run_id": "run-1",
"caller": "subagent:test-agent",
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
}
]
def set_terminal():
try:
assert result.try_set_terminal(
SubagentStatus.COMPLETED,
result="done",
token_usage_records=token_usage_records,
)
except BaseException as exc:
writer_errors.append(exc)
writer = threading.Thread(target=set_terminal)
writer.start()
assert now_entered.wait(timeout=3), "try_set_terminal did not reach completed_at assignment"
assert result.completed_at is None
assert result.status == SubagentStatus.RUNNING
assert result.token_usage_records == token_usage_records
release_now.set()
writer.join(timeout=3)
assert not writer.is_alive(), "try_set_terminal did not finish"
assert writer_errors == []
assert result.completed_at == completed_at
assert result.status == SubagentStatus.COMPLETED
assert result.result == "done"
assert result.token_usage_records == token_usage_records
# -----------------------------------------------------------------------------
# Cleanup Background Task Tests
@@ -1604,6 +1675,69 @@ class TestCooperativeCancellation:
assert result.error == "Cancelled by user"
assert result.completed_at is not None
def test_late_completion_after_timeout_does_not_overwrite_timed_out(self, executor_module, classes, msg):
"""Late completion from the execution worker must not overwrite TIMED_OUT."""
SubagentExecutor = classes["SubagentExecutor"]
SubagentStatus = classes["SubagentStatus"]
short_config = classes["SubagentConfig"](
name="test-agent",
description="Test agent",
system_prompt="You are a test agent.",
max_turns=10,
timeout_seconds=0.05,
)
first_chunk_seen = threading.Event()
finish_stream = threading.Event()
execution_done = threading.Event()
async def mock_astream(*args, **kwargs):
yield {"messages": [msg.human("Task"), msg.ai("late completion", "msg-late")]}
first_chunk_seen.set()
deadline = asyncio.get_running_loop().time() + 5
while not finish_stream.is_set():
if asyncio.get_running_loop().time() >= deadline:
break
await asyncio.sleep(0.001)
mock_agent = MagicMock()
mock_agent.astream = mock_astream
executor = SubagentExecutor(
config=short_config,
tools=[],
thread_id="test-thread",
trace_id="test-trace",
)
original_aexecute = executor._aexecute
async def tracked_aexecute(task, result_holder=None):
try:
return await original_aexecute(task, result_holder)
finally:
execution_done.set()
with patch.object(executor, "_create_agent", return_value=mock_agent), patch.object(executor, "_aexecute", tracked_aexecute):
task_id = executor.execute_async("Task")
assert first_chunk_seen.wait(timeout=3), "stream did not yield initial chunk"
result = executor_module._background_tasks[task_id]
assert result.cancel_event.wait(timeout=3), "timeout handler did not request cancellation"
assert result.status.value == SubagentStatus.TIMED_OUT.value
timed_out_error = result.error
timed_out_completed_at = result.completed_at
finish_stream.set()
assert execution_done.wait(timeout=3), "execution worker did not finish"
result = executor_module._background_tasks.get(task_id)
assert result is not None
assert result.status.value == SubagentStatus.TIMED_OUT.value
assert result.result is None
assert result.error == timed_out_error
assert result.completed_at == timed_out_completed_at
def test_cleanup_removes_cancelled_task(self, executor_module, classes):
"""Test that cleanup removes a CANCELLED task (terminal state)."""
SubagentResult = classes["SubagentResult"]