fix(middleware): fix LLM fallback run status (#3321)

* Fix LLM fallback run status

* optimize LLM fallback maker extraction in streaming path
This commit is contained in:
Nan Gao
2026-05-31 16:42:13 +02:00
committed by GitHub
parent 9f3be2a9fa
commit 79cc227917
5 changed files with 362 additions and 5 deletions
@@ -177,6 +177,24 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
def _build_circuit_breaker_message(self) -> str:
return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again."
def _build_error_fallback_message(
self,
content: str,
*,
error_type: str,
reason: str,
detail: str,
) -> AIMessage:
return AIMessage(
content=content,
additional_kwargs={
"deerflow_error_fallback": True,
"error_type": error_type,
"error_reason": reason,
"error_detail": detail,
},
)
def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc)
if reason == "quota":
@@ -187,6 +205,14 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
return f"LLM request failed: {detail}"
def _build_user_fallback_message(self, exc: BaseException, reason: str) -> AIMessage:
return self._build_error_fallback_message(
self._build_user_message(exc, reason),
error_type=type(exc).__name__,
reason=reason,
detail=_extract_error_detail(exc),
)
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
try:
from langgraph.config import get_stream_writer
@@ -212,7 +238,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
return self._build_error_fallback_message(
self._build_circuit_breaker_message(),
error_type="CircuitBreakerOpen",
reason="circuit_open",
detail="LLM circuit breaker is open",
)
attempt = 1
while True:
@@ -249,7 +280,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
)
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason))
return self._build_user_fallback_message(exc, reason)
@override
async def awrap_model_call(
@@ -258,7 +289,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
return self._build_error_fallback_message(
self._build_circuit_breaker_message(),
error_type="CircuitBreakerOpen",
reason="circuit_open",
detail="LLM circuit breaker is open",
)
attempt = 1
while True:
@@ -295,7 +331,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
)
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason))
return self._build_user_fallback_message(exc, reason)
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
@@ -86,6 +86,8 @@ class RunJournal(BaseCallbackHandler):
self._last_ai_msg: str | None = None
self._first_human_msg: str | None = None
self._msg_count = 0
self._had_llm_error_fallback = False
self._llm_error_fallback_message: str | None = None
# Latency tracking
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
@@ -256,6 +258,18 @@ class RunJournal(BaseCallbackHandler):
# Token usage from message
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
additional_kwargs = getattr(message, "additional_kwargs", None) or {}
if isinstance(additional_kwargs, dict) and additional_kwargs.get("deerflow_error_fallback"):
self._had_llm_error_fallback = True
detail = additional_kwargs.get("error_detail")
reason = additional_kwargs.get("error_reason")
fallback_text = self._message_text(message).strip()
if isinstance(detail, str) and detail.strip():
self._llm_error_fallback_message = detail.strip()
elif isinstance(reason, str) and reason.strip():
self._llm_error_fallback_message = reason.strip()
elif fallback_text:
self._llm_error_fallback_message = fallback_text[:2000]
# Resolve call index
call_index = self._llm_call_index
@@ -569,3 +583,11 @@ class RunJournal(BaseCallbackHandler):
"last_ai_message": self._last_ai_msg,
"first_human_message": self._first_human_msg,
}
@property
def had_llm_error_fallback(self) -> bool:
return self._had_llm_error_fallback
@property
def llm_error_fallback_message(self) -> str | None:
return self._llm_error_fallback_message
@@ -150,6 +150,7 @@ async def run_agent(
pre_run_checkpoint_id: str | None = None
pre_run_snapshot: dict[str, Any] | None = None
snapshot_capture_failed = False
llm_error_fallback_message: str | None = None
journal = None
@@ -312,6 +313,7 @@ async def run_agent(
if record.abort_event.is_set():
logger.info("Run %s abort requested — stopping", run_id)
break
llm_error_fallback_message = llm_error_fallback_message or _extract_llm_error_fallback_message(chunk)
sse_event = _lg_mode_to_sse_event(single_mode)
await bridge.publish(run_id, sse_event, serialize(chunk, mode=single_mode))
else:
@@ -330,6 +332,7 @@ async def run_agent(
if mode is None:
continue
llm_error_fallback_message = llm_error_fallback_message or _extract_llm_error_fallback_message(chunk)
sse_event = _lg_mode_to_sse_event(mode)
await bridge.publish(run_id, sse_event, serialize(chunk, mode=mode))
@@ -352,6 +355,12 @@ async def run_agent(
logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True)
else:
await run_manager.set_status(run_id, RunStatus.interrupted)
elif llm_error_fallback_message or (journal is not None and journal.had_llm_error_fallback):
error_msg = llm_error_fallback_message
if error_msg is None and journal is not None:
error_msg = journal.llm_error_fallback_message
error_msg = error_msg or "LLM provider failed after retries"
await run_manager.set_status(run_id, RunStatus.error, error=error_msg)
else:
await run_manager.set_status(run_id, RunStatus.success)
@@ -554,6 +563,85 @@ def _lg_mode_to_sse_event(mode: str) -> str:
return mode
def _error_fallback_message_from_metadata(metadata: dict[str, Any], content: Any) -> str:
detail = metadata.get("error_detail")
if isinstance(detail, str) and detail.strip():
return detail.strip()
reason = metadata.get("error_reason")
if isinstance(reason, str) and reason.strip():
return reason.strip()
if isinstance(content, str) and content.strip():
return content.strip()[:2000]
return "LLM provider failed after retries"
def _try_extract_from_message(obj: Any) -> str | None:
"""Try to extract fallback marker from a single message object or dict."""
additional_kwargs = getattr(obj, "additional_kwargs", None)
if isinstance(additional_kwargs, dict) and additional_kwargs.get("deerflow_error_fallback"):
return _error_fallback_message_from_metadata(additional_kwargs, getattr(obj, "content", None))
if isinstance(obj, dict):
nested_kwargs = obj.get("additional_kwargs")
if isinstance(nested_kwargs, dict) and nested_kwargs.get("deerflow_error_fallback"):
return _error_fallback_message_from_metadata(nested_kwargs, obj.get("content"))
return None
def _extract_llm_error_fallback_message(value: Any) -> str | None:
"""Find LLM fallback markers in streamed LangGraph chunks.
Error fallback messages returned by model-call middleware are not guaranteed
to pass through LLM end callbacks, but they do appear in graph state chunks.
"""
# Fast path: large state chunks produced by stream_mode="values" have a
# top-level "messages" list. Scanning only that list avoids expensive deep
# recursion into large state dicts.
if isinstance(value, dict):
messages = value.get("messages")
if isinstance(messages, (list, tuple)):
for msg in messages:
result = _try_extract_from_message(msg)
if result is not None:
return result
# Fallback marker is attached to an AI message in the messages
# channel; it will never appear elsewhere in a values chunk.
return None
# No top-level "messages" — this is likely an "updates" chunk (small
# dict keyed by node name). Fall through to deep walk, which is cheap
# for these payloads.
# Deep walk for updates / messages / tuple / list modes. Payloads are
# small, so full recursion is acceptable here.
seen: set[int] = set()
def walk(obj: Any) -> str | None:
oid = id(obj)
if oid in seen:
return None
seen.add(oid)
result = _try_extract_from_message(obj)
if result is not None:
return result
if isinstance(obj, dict):
for item in obj.values():
result = walk(item)
if result is not None:
return result
return None
if isinstance(obj, (list, tuple, set)):
for item in obj:
result = walk(item)
if result is not None:
return result
return None
return walk(value)
def _extract_human_message(graph_input: dict) -> HumanMessage | None:
"""Extract or construct a HumanMessage from graph_input for event recording.