fix(middleware): fix LLM fallback run status (#3321)

* Fix LLM fallback run status

* Optimize LLM fallback marker extraction in streaming path
This commit is contained in:
Nan Gao
2026-05-31 16:42:13 +02:00
committed by GitHub
parent 9f3be2a9fa
commit 79cc227917
5 changed files with 362 additions and 5 deletions
@@ -94,6 +94,31 @@ def test_async_model_call_returns_user_message_for_quota_errors() -> None:
assert isinstance(result, AIMessage)
assert "out of quota" in str(result.content)
assert result.additional_kwargs["deerflow_error_fallback"] is True
assert result.additional_kwargs["error_reason"] == "quota"
assert result.additional_kwargs["error_type"] == "FakeError"
def test_async_model_call_marks_transient_retry_exhaustion_as_error_fallback(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Exhausting retries on a 503 error must produce a marked fallback AIMessage.

    The handler always raises, so the middleware (configured for two attempts)
    is expected to give up and return a user-facing message whose
    ``additional_kwargs`` carry the fallback marker, reason and detail.
    """

    async def always_failing_handler(_request) -> AIMessage:
        raise FakeError("Connection error.", status_code=503)

    async def instant_sleep(_delay: float) -> None:
        # Stand-in for asyncio.sleep so the test does not actually wait.
        return None

    middleware = _build_middleware(
        retry_max_attempts=2,
        retry_base_delay_ms=25,
        retry_cap_delay_ms=25,
    )
    monkeypatch.setattr("asyncio.sleep", instant_sleep)

    pending_call = middleware.awrap_model_call(SimpleNamespace(), always_failing_handler)
    result = asyncio.run(pending_call)

    assert isinstance(result, AIMessage)
    assert "temporarily unavailable" in str(result.content)

    extras = result.additional_kwargs
    assert extras["deerflow_error_fallback"] is True
    assert extras["error_reason"] == "transient"
    assert extras["error_detail"] == "Connection error."
def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None:
+187 -1
View File
@@ -3,12 +3,22 @@ from types import SimpleNamespace
from unittest.mock import AsyncMock, call
import pytest
from langchain_core.messages import AIMessage
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.runtime.runs.manager import RunManager
from deerflow.runtime.runs.schemas import RunStatus
from deerflow.runtime.runs.worker import RunContext, _agent_factory_supports_app_config, _build_runtime_context, _install_runtime_context, _rollback_to_pre_run_checkpoint, run_agent
from deerflow.runtime.runs.worker import (
RunContext,
_agent_factory_supports_app_config,
_build_runtime_context,
_extract_llm_error_fallback_message,
_install_runtime_context,
_rollback_to_pre_run_checkpoint,
_try_extract_from_message,
run_agent,
)
class FakeCheckpointer:
@@ -95,6 +105,52 @@ async def test_run_agent_threads_explicit_app_config_into_config_only_factory():
bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60)
@pytest.mark.anyio
async def test_run_agent_marks_llm_error_fallback_as_error_status():
    """A streamed fallback-marked AIMessage must leave the run in error status.

    The dummy agent streams one chunk whose only message carries the
    ``deerflow_error_fallback`` marker; afterwards the stored run is expected
    to be ``RunStatus.error`` with the marker's ``error_detail`` as its error.
    """
    run_manager = RunManager()
    record = await run_manager.create("thread-1")
    bridge = SimpleNamespace(**{name: AsyncMock() for name in ("publish", "publish_end", "cleanup")})

    class DummyAgent:
        async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False):
            fallback_kwargs = {
                "deerflow_error_fallback": True,
                "error_type": "APIConnectionError",
                "error_reason": "transient",
                "error_detail": "Connection error.",
            }
            fallback_message = AIMessage(
                content="The configured LLM provider is temporarily unavailable after multiple retries.",
                additional_kwargs=fallback_kwargs,
            )
            yield {"messages": [fallback_message]}

    def factory(*, config):
        # Keyword-only ``config`` signature is kept exactly as the worker sees it.
        return DummyAgent()

    await run_agent(
        bridge,
        run_manager,
        record,
        ctx=RunContext(checkpointer=None),
        agent_factory=factory,
        graph_input={},
        config={},
    )

    stored_run = await run_manager.get(record.run_id)
    assert stored_run is not None
    assert stored_run.status == RunStatus.error
    assert stored_run.error == "Connection error."
    bridge.publish_end.assert_awaited_once_with(record.run_id)
@pytest.mark.anyio
async def test_run_agent_defaults_root_run_name_from_assistant_id():
run_manager = RunManager()
@@ -486,3 +542,133 @@ def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_f
monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", lambda _obj: (_ for _ in ()).throw(ValueError("boom")))
assert _agent_factory_supports_app_config(BrokenCallable()) is False
# ---------------------------------------------------------------------------
# _extract_llm_error_fallback_message coverage
# ---------------------------------------------------------------------------
def test_try_extract_from_message_finds_fallback_on_message_object():
    """A fallback-marked AIMessage object yields its ``error_detail``."""
    fallback_kwargs = {
        "deerflow_error_fallback": True,
        "error_detail": "Connection error.",
        "error_reason": "transient",
    }
    message = AIMessage(content="fallback", additional_kwargs=fallback_kwargs)

    extracted = _try_extract_from_message(message)

    assert extracted == "Connection error."
def test_try_extract_from_message_finds_fallback_on_dict():
    """A plain-dict message with the fallback marker yields its ``error_detail``."""
    fallback_kwargs = {
        "deerflow_error_fallback": True,
        "error_detail": "Quota exceeded.",
    }
    message = {"content": "fallback", "additional_kwargs": fallback_kwargs}

    extracted = _try_extract_from_message(message)

    assert extracted == "Quota exceeded."
def test_try_extract_from_message_returns_none_for_normal_message():
    """An AIMessage without any fallback marker yields ``None``."""
    ordinary_message = AIMessage(content="hello")
    extracted = _try_extract_from_message(ordinary_message)
    assert extracted is None
def test_extract_llm_error_fallback_message_large_state_chunk_no_fallback():
    """A large state dict without fallback markers must not raise and should return None."""
    conversation = [
        AIMessage(content="Hello!"),
        {"role": "user", "content": "Hi there"},
    ]
    deeply_nested = {"nested": {"deep": {"data": list(range(1000))}}}
    bulky_rows = [{"id": i, "payload": "y" * 1000} for i in range(500)]

    large_state = {
        "messages": conversation,
        "foo": "x" * 10_000,
        "bar": deeply_nested,
        "baz": bulky_rows,
    }

    extracted = _extract_llm_error_fallback_message(large_state)
    assert extracted is None
def test_extract_llm_error_fallback_message_finds_fallback_in_messages_list():
    """A marked message after a normal one in ``messages`` yields its ``error_detail``."""
    normal_message = AIMessage(content="Hello!")
    fallback_message = AIMessage(
        content="Unavailable.",
        additional_kwargs={
            "deerflow_error_fallback": True,
            "error_detail": "Connection error.",
        },
    )
    state = {
        "messages": [normal_message, fallback_message],
        "other_state": "large_value" * 1000,
    }

    extracted = _extract_llm_error_fallback_message(state)

    assert extracted == "Connection error."
def test_extract_llm_error_fallback_message_finds_fallback_in_raw_message():
    """A bare marked message with only ``error_reason`` set yields that reason."""
    marker_kwargs = {
        "deerflow_error_fallback": True,
        "error_reason": "quota",
    }
    raw_message = AIMessage(content="Unavailable.", additional_kwargs=marker_kwargs)

    extracted = _extract_llm_error_fallback_message(raw_message)

    assert extracted == "quota"
def test_extract_llm_error_fallback_message_finds_fallback_in_tuple():
    """A ``("messages", message)`` tuple holding a marked message yields its ``error_detail``."""
    fallback_message = AIMessage(
        content="Unavailable.",
        additional_kwargs={
            "deerflow_error_fallback": True,
            "error_detail": "Circuit open.",
        },
    )
    item = ("messages", fallback_message)

    extracted = _extract_llm_error_fallback_message(item)

    assert extracted == "Circuit open."
def test_extract_llm_error_fallback_message_returns_none_for_empty_values():
    """Empty containers, ``None`` and a plain string all yield ``None``."""
    markerless_inputs = ({}, [], None, "string")
    for candidate in markerless_inputs:
        assert _extract_llm_error_fallback_message(candidate) is None
def test_extract_llm_error_fallback_message_finds_fallback_in_updates_mode():
    """A marker nested under a node-name key is still found.

    stream_mode='updates' yields dicts keyed by node name
    (e.g. {'call_model': {...}}), so the fallback message sits inside the
    node's state update rather than at the top level of the chunk.
    """
    fallback_message = AIMessage(
        content="Unavailable.",
        additional_kwargs={
            "deerflow_error_fallback": True,
            "error_detail": "Connection error.",
        },
    )
    node_update = {"messages": [fallback_message]}
    update_chunk = {"call_model": node_update}

    extracted = _extract_llm_error_fallback_message(update_chunk)

    assert extracted == "Connection error."
def test_extract_llm_error_fallback_message_updates_mode_no_fallback():
    """An updates chunk holding only an interrupt entry yields ``None`` safely."""
    interrupt_entry = {
        "value": "ask_human",
        "resumable": True,
        "ns": ["agent"],
        "when": "during",
    }
    update_chunk = {"__interrupt__": [interrupt_entry]}

    extracted = _extract_llm_error_fallback_message(update_chunk)

    assert extracted is None