mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 23:46:50 +00:00
fix(subagents): make subagent timeout terminal state atomic (#2583)
* Guard subagent terminal state transitions * fix: publish subagent terminal status last * Fix subagent timeout test to avoid blocking event loop * Fix subagent timeout test tracking * Refine subagent terminal state handling --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -1125,6 +1125,15 @@ class TestAsyncToolSupport:
|
||||
class TestThreadSafety:
|
||||
"""Test thread safety of executor operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def executor_module(self, _setup_executor_classes):
|
||||
"""Import the executor module with real classes."""
|
||||
import importlib
|
||||
|
||||
from deerflow.subagents import executor
|
||||
|
||||
return importlib.reload(executor)
|
||||
|
||||
def test_multiple_executors_in_parallel(self, classes, base_config, msg):
|
||||
"""Test multiple executors running in parallel via thread pool."""
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
@@ -1170,6 +1179,68 @@ class TestThreadSafety:
|
||||
assert result.status == SubagentStatus.COMPLETED
|
||||
assert "Result" in result.result
|
||||
|
||||
def test_terminal_status_is_published_after_payload_fields(self, executor_module, monkeypatch):
|
||||
"""Readers must not observe terminal status before terminal payload is complete."""
|
||||
SubagentResult = executor_module.SubagentResult
|
||||
SubagentStatus = executor_module.SubagentStatus
|
||||
|
||||
now_entered = threading.Event()
|
||||
release_now = threading.Event()
|
||||
completed_at = datetime(2026, 5, 1, 12, 0, 0)
|
||||
writer_errors: list[BaseException] = []
|
||||
|
||||
class BlockingDateTime:
|
||||
@staticmethod
|
||||
def now():
|
||||
now_entered.set()
|
||||
release_now.wait(timeout=5)
|
||||
return completed_at
|
||||
|
||||
monkeypatch.setattr(executor_module, "datetime", BlockingDateTime)
|
||||
|
||||
result = SubagentResult(
|
||||
task_id="test-terminal-publication-order",
|
||||
trace_id="test-trace",
|
||||
status=SubagentStatus.RUNNING,
|
||||
)
|
||||
token_usage_records = [
|
||||
{
|
||||
"source_run_id": "run-1",
|
||||
"caller": "subagent:test-agent",
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
}
|
||||
]
|
||||
|
||||
def set_terminal():
|
||||
try:
|
||||
assert result.try_set_terminal(
|
||||
SubagentStatus.COMPLETED,
|
||||
result="done",
|
||||
token_usage_records=token_usage_records,
|
||||
)
|
||||
except BaseException as exc:
|
||||
writer_errors.append(exc)
|
||||
|
||||
writer = threading.Thread(target=set_terminal)
|
||||
writer.start()
|
||||
|
||||
assert now_entered.wait(timeout=3), "try_set_terminal did not reach completed_at assignment"
|
||||
assert result.completed_at is None
|
||||
assert result.status == SubagentStatus.RUNNING
|
||||
assert result.token_usage_records == token_usage_records
|
||||
|
||||
release_now.set()
|
||||
writer.join(timeout=3)
|
||||
|
||||
assert not writer.is_alive(), "try_set_terminal did not finish"
|
||||
assert writer_errors == []
|
||||
assert result.completed_at == completed_at
|
||||
assert result.status == SubagentStatus.COMPLETED
|
||||
assert result.result == "done"
|
||||
assert result.token_usage_records == token_usage_records
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Cleanup Background Task Tests
|
||||
@@ -1604,6 +1675,69 @@ class TestCooperativeCancellation:
|
||||
assert result.error == "Cancelled by user"
|
||||
assert result.completed_at is not None
|
||||
|
||||
def test_late_completion_after_timeout_does_not_overwrite_timed_out(self, executor_module, classes, msg):
|
||||
"""Late completion from the execution worker must not overwrite TIMED_OUT."""
|
||||
SubagentExecutor = classes["SubagentExecutor"]
|
||||
SubagentStatus = classes["SubagentStatus"]
|
||||
|
||||
short_config = classes["SubagentConfig"](
|
||||
name="test-agent",
|
||||
description="Test agent",
|
||||
system_prompt="You are a test agent.",
|
||||
max_turns=10,
|
||||
timeout_seconds=0.05,
|
||||
)
|
||||
|
||||
first_chunk_seen = threading.Event()
|
||||
finish_stream = threading.Event()
|
||||
execution_done = threading.Event()
|
||||
|
||||
async def mock_astream(*args, **kwargs):
|
||||
yield {"messages": [msg.human("Task"), msg.ai("late completion", "msg-late")]}
|
||||
first_chunk_seen.set()
|
||||
deadline = asyncio.get_running_loop().time() + 5
|
||||
while not finish_stream.is_set():
|
||||
if asyncio.get_running_loop().time() >= deadline:
|
||||
break
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.astream = mock_astream
|
||||
|
||||
executor = SubagentExecutor(
|
||||
config=short_config,
|
||||
tools=[],
|
||||
thread_id="test-thread",
|
||||
trace_id="test-trace",
|
||||
)
|
||||
original_aexecute = executor._aexecute
|
||||
|
||||
async def tracked_aexecute(task, result_holder=None):
|
||||
try:
|
||||
return await original_aexecute(task, result_holder)
|
||||
finally:
|
||||
execution_done.set()
|
||||
|
||||
with patch.object(executor, "_create_agent", return_value=mock_agent), patch.object(executor, "_aexecute", tracked_aexecute):
|
||||
task_id = executor.execute_async("Task")
|
||||
assert first_chunk_seen.wait(timeout=3), "stream did not yield initial chunk"
|
||||
|
||||
result = executor_module._background_tasks[task_id]
|
||||
assert result.cancel_event.wait(timeout=3), "timeout handler did not request cancellation"
|
||||
assert result.status.value == SubagentStatus.TIMED_OUT.value
|
||||
timed_out_error = result.error
|
||||
timed_out_completed_at = result.completed_at
|
||||
|
||||
finish_stream.set()
|
||||
assert execution_done.wait(timeout=3), "execution worker did not finish"
|
||||
|
||||
result = executor_module._background_tasks.get(task_id)
|
||||
assert result is not None
|
||||
assert result.status.value == SubagentStatus.TIMED_OUT.value
|
||||
assert result.result is None
|
||||
assert result.error == timed_out_error
|
||||
assert result.completed_at == timed_out_completed_at
|
||||
|
||||
def test_cleanup_removes_cancelled_task(self, executor_module, classes):
|
||||
"""Test that cleanup removes a CANCELLED task (terminal state)."""
|
||||
SubagentResult = classes["SubagentResult"]
|
||||
|
||||
Reference in New Issue
Block a user