diff --git a/backend/packages/harness/deerflow/subagents/executor.py b/backend/packages/harness/deerflow/subagents/executor.py index d6d2e4fc5..8fcbd5e1d 100644 --- a/backend/packages/harness/deerflow/subagents/executor.py +++ b/backend/packages/harness/deerflow/subagents/executor.py @@ -47,6 +47,15 @@ class SubagentStatus(Enum): CANCELLED = "cancelled" TIMED_OUT = "timed_out" + @property + def is_terminal(self) -> bool: + return self in { + type(self).COMPLETED, + type(self).FAILED, + type(self).CANCELLED, + type(self).TIMED_OUT, + } + @dataclass class SubagentResult: @@ -74,12 +83,48 @@ class SubagentResult: token_usage_records: list[dict[str, int | str]] = field(default_factory=list) usage_reported: bool = False cancel_event: threading.Event = field(default_factory=threading.Event, repr=False) + _state_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) def __post_init__(self): """Initialize mutable defaults.""" if self.ai_messages is None: self.ai_messages = [] + def try_set_terminal( + self, + status: SubagentStatus, + *, + result: str | None = None, + error: str | None = None, + completed_at: datetime | None = None, + ai_messages: list[dict[str, Any]] | None = None, + token_usage_records: list[dict[str, int | str]] | None = None, + ) -> bool: + """Set a terminal status exactly once; return True only if this call made the transition. + + Background timeout/cancellation and the execution worker can race on the + same result holder. The first terminal transition wins; late terminal + writes must not change status or payload fields. 
+ """ + if not status.is_terminal: + raise ValueError(f"Status {status} is not terminal") + + with self._state_lock: + if self.status.is_terminal: + return False + + if result is not None: + self.result = result + if error is not None: + self.error = error + if ai_messages is not None: + self.ai_messages = ai_messages + if token_usage_records is not None: + self.token_usage_records = token_usage_records + self.completed_at = completed_at or datetime.now() + self.status = status + return True + # Global storage for background task results _background_tasks: dict[str, SubagentResult] = {} @@ -459,13 +504,11 @@ class SubagentExecutor: # Pre-check: bail out immediately if already cancelled before streaming starts if result.cancel_event.is_set(): logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled before streaming") - with _background_tasks_lock: - if result.status == SubagentStatus.RUNNING: - result.status = SubagentStatus.CANCELLED - result.error = "Cancelled by user" - result.completed_at = datetime.now() - if collector is not None: - result.token_usage_records = collector.snapshot_records() + result.try_set_terminal( + SubagentStatus.CANCELLED, + error="Cancelled by user", + token_usage_records=collector.snapshot_records() if collector is not None else None, + ) return result async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type] @@ -475,12 +518,11 @@ class SubagentExecutor: # interrupted until the next chunk is yielded. 
if result.cancel_event.is_set(): logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled by parent") - with _background_tasks_lock: - if result.status == SubagentStatus.RUNNING: - result.status = SubagentStatus.CANCELLED - result.error = "Cancelled by user" - result.completed_at = datetime.now() - result.token_usage_records = collector.snapshot_records() + result.try_set_terminal( + SubagentStatus.CANCELLED, + error="Cancelled by user", + token_usage_records=collector.snapshot_records(), + ) return result final_state = chunk @@ -507,11 +549,12 @@ class SubagentExecutor: logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}") logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution") - result.token_usage_records = collector.snapshot_records() + token_usage_records = collector.snapshot_records() + final_result: str | None = None if final_state is None: logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state") - result.result = "No response generated" + final_result = "No response generated" else: # Extract the final message - find the last AIMessage messages = final_state.get("messages", []) @@ -528,7 +571,7 @@ class SubagentExecutor: content = last_ai_message.content # Handle both str and list content types for the final result if isinstance(content, str): - result.result = content + final_result = content elif isinstance(content, list): # Extract text from list of content blocks for final result only. 
# Concatenate raw string chunks directly, but preserve separation @@ -547,16 +590,16 @@ class SubagentExecutor: text_parts.append(text_val) if pending_str_parts: text_parts.append("".join(pending_str_parts)) - result.result = "\n".join(text_parts) if text_parts else "No text content in response" + final_result = "\n".join(text_parts) if text_parts else "No text content in response" else: - result.result = str(content) + final_result = str(content) elif messages: # Fallback: use the last message if no AIMessage found last_message = messages[-1] logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}") raw_content = last_message.content if hasattr(last_message, "content") else str(last_message) if isinstance(raw_content, str): - result.result = raw_content + final_result = raw_content elif isinstance(raw_content, list): parts = [] pending_str_parts = [] @@ -572,23 +615,29 @@ class SubagentExecutor: parts.append(text_val) if pending_str_parts: parts.append("".join(pending_str_parts)) - result.result = "\n".join(parts) if parts else "No text content in response" + final_result = "\n".join(parts) if parts else "No text content in response" else: - result.result = str(raw_content) + final_result = str(raw_content) else: logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state") - result.result = "No response generated" + final_result = "No response generated" - result.status = SubagentStatus.COMPLETED - result.completed_at = datetime.now() + if final_result is None: + final_result = "No response generated" + + result.try_set_terminal( + SubagentStatus.COMPLETED, + result=final_result, + token_usage_records=token_usage_records, + ) except Exception as e: logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed") - result.status = SubagentStatus.FAILED - result.error = str(e) - result.completed_at = datetime.now() - if 
collector is not None: - result.token_usage_records = collector.snapshot_records() + result.try_set_terminal( + SubagentStatus.FAILED, + error=str(e), + token_usage_records=collector.snapshot_records() if collector is not None else None, + ) return result @@ -667,11 +716,9 @@ class SubagentExecutor: result = SubagentResult( task_id=str(uuid.uuid4())[:8], trace_id=self.trace_id, - status=SubagentStatus.FAILED, + status=SubagentStatus.RUNNING, ) - result.status = SubagentStatus.FAILED - result.error = str(e) - result.completed_at = datetime.now() + result.try_set_terminal(SubagentStatus.FAILED, error=str(e)) return result def execute_async(self, task: str, task_id: str | None = None) -> str: @@ -718,29 +765,21 @@ class SubagentExecutor: ) try: # Wait for execution with timeout - exec_result = execution_future.result(timeout=self.config.timeout_seconds) - with _background_tasks_lock: - _background_tasks[task_id].status = exec_result.status - _background_tasks[task_id].result = exec_result.result - _background_tasks[task_id].error = exec_result.error - _background_tasks[task_id].completed_at = datetime.now() - _background_tasks[task_id].ai_messages = exec_result.ai_messages + execution_future.result(timeout=self.config.timeout_seconds) except FuturesTimeoutError: logger.error(f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s") - with _background_tasks_lock: - if _background_tasks[task_id].status == SubagentStatus.RUNNING: - _background_tasks[task_id].status = SubagentStatus.TIMED_OUT - _background_tasks[task_id].error = f"Execution timed out after {self.config.timeout_seconds} seconds" - _background_tasks[task_id].completed_at = datetime.now() + result_holder.try_set_terminal( # publish TIMED_OUT before cancel_event is set so the worker's CANCELLED write cannot win the race + SubagentStatus.TIMED_OUT, + error=f"Execution timed out after {self.config.timeout_seconds} seconds", + ) # Signal cooperative cancellation and cancel the future result_holder.cancel_event.set() execution_future.cancel() except 
Exception as e: logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed") with _background_tasks_lock: - _background_tasks[task_id].status = SubagentStatus.FAILED - _background_tasks[task_id].error = str(e) - _background_tasks[task_id].completed_at = datetime.now() + task_result = _background_tasks[task_id] + task_result.try_set_terminal(SubagentStatus.FAILED, error=str(e)) _scheduler_pool.submit(run_task) return task_id @@ -811,13 +850,7 @@ def cleanup_background_task(task_id: str) -> None: # Only clean up tasks that are in a terminal state to avoid races with # the background executor still updating the task entry. - is_terminal_status = result.status in { - SubagentStatus.COMPLETED, - SubagentStatus.FAILED, - SubagentStatus.CANCELLED, - SubagentStatus.TIMED_OUT, - } - if is_terminal_status or result.completed_at is not None: + if result.status.is_terminal or result.completed_at is not None: del _background_tasks[task_id] logger.debug("Cleaned up background task: %s", task_id) else: diff --git a/backend/tests/test_subagent_executor.py b/backend/tests/test_subagent_executor.py index 87c82ff96..8987958a8 100644 --- a/backend/tests/test_subagent_executor.py +++ b/backend/tests/test_subagent_executor.py @@ -1125,6 +1125,15 @@ class TestAsyncToolSupport: class TestThreadSafety: """Test thread safety of executor operations.""" + @pytest.fixture + def executor_module(self, _setup_executor_classes): + """Import the executor module with real classes.""" + import importlib + + from deerflow.subagents import executor + + return importlib.reload(executor) + def test_multiple_executors_in_parallel(self, classes, base_config, msg): """Test multiple executors running in parallel via thread pool.""" from concurrent.futures import ThreadPoolExecutor, as_completed @@ -1170,6 +1179,68 @@ class TestThreadSafety: assert result.status == SubagentStatus.COMPLETED assert "Result" in result.result + def 
test_terminal_status_is_published_after_payload_fields(self, executor_module, monkeypatch): + """Readers must not observe terminal status before terminal payload is complete.""" + SubagentResult = executor_module.SubagentResult + SubagentStatus = executor_module.SubagentStatus + + now_entered = threading.Event() + release_now = threading.Event() + completed_at = datetime(2026, 5, 1, 12, 0, 0) + writer_errors: list[BaseException] = [] + + class BlockingDateTime: + @staticmethod + def now(): + now_entered.set() + release_now.wait(timeout=5) + return completed_at + + monkeypatch.setattr(executor_module, "datetime", BlockingDateTime) + + result = SubagentResult( + task_id="test-terminal-publication-order", + trace_id="test-trace", + status=SubagentStatus.RUNNING, + ) + token_usage_records = [ + { + "source_run_id": "run-1", + "caller": "subagent:test-agent", + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + } + ] + + def set_terminal(): + try: + assert result.try_set_terminal( + SubagentStatus.COMPLETED, + result="done", + token_usage_records=token_usage_records, + ) + except BaseException as exc: + writer_errors.append(exc) + + writer = threading.Thread(target=set_terminal) + writer.start() + + assert now_entered.wait(timeout=3), "try_set_terminal did not reach completed_at assignment" + assert result.completed_at is None + assert result.status == SubagentStatus.RUNNING + assert result.token_usage_records == token_usage_records + + release_now.set() + writer.join(timeout=3) + + assert not writer.is_alive(), "try_set_terminal did not finish" + assert writer_errors == [] + assert result.completed_at == completed_at + assert result.status == SubagentStatus.COMPLETED + assert result.result == "done" + assert result.token_usage_records == token_usage_records + # ----------------------------------------------------------------------------- # Cleanup Background Task Tests @@ -1604,6 +1675,69 @@ class TestCooperativeCancellation: assert result.error == 
"Cancelled by user" assert result.completed_at is not None + def test_late_completion_after_timeout_does_not_overwrite_timed_out(self, executor_module, classes, msg): + """Late completion from the execution worker must not overwrite TIMED_OUT.""" + SubagentExecutor = classes["SubagentExecutor"] + SubagentStatus = classes["SubagentStatus"] + + short_config = classes["SubagentConfig"]( + name="test-agent", + description="Test agent", + system_prompt="You are a test agent.", + max_turns=10, + timeout_seconds=0.05, + ) + + first_chunk_seen = threading.Event() + finish_stream = threading.Event() + execution_done = threading.Event() + + async def mock_astream(*args, **kwargs): + yield {"messages": [msg.human("Task"), msg.ai("late completion", "msg-late")]} + first_chunk_seen.set() + deadline = asyncio.get_running_loop().time() + 5 + while not finish_stream.is_set(): + if asyncio.get_running_loop().time() >= deadline: + break + await asyncio.sleep(0.001) + + mock_agent = MagicMock() + mock_agent.astream = mock_astream + + executor = SubagentExecutor( + config=short_config, + tools=[], + thread_id="test-thread", + trace_id="test-trace", + ) + original_aexecute = executor._aexecute + + async def tracked_aexecute(task, result_holder=None): + try: + return await original_aexecute(task, result_holder) + finally: + execution_done.set() + + with patch.object(executor, "_create_agent", return_value=mock_agent), patch.object(executor, "_aexecute", tracked_aexecute): + task_id = executor.execute_async("Task") + assert first_chunk_seen.wait(timeout=3), "stream did not yield initial chunk" + + result = executor_module._background_tasks[task_id] + assert result.cancel_event.wait(timeout=3), "timeout handler did not request cancellation" + assert result.status.value == SubagentStatus.TIMED_OUT.value + timed_out_error = result.error + timed_out_completed_at = result.completed_at + + finish_stream.set() + assert execution_done.wait(timeout=3), "execution worker did not finish" + + result 
= executor_module._background_tasks.get(task_id) + assert result is not None + assert result.status.value == SubagentStatus.TIMED_OUT.value + assert result.result is None + assert result.error == timed_out_error + assert result.completed_at == timed_out_completed_at + def test_cleanup_removes_cancelled_task(self, executor_module, classes): """Test that cleanup removes a CANCELLED task (terminal state).""" SubagentResult = classes["SubagentResult"]