mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
fix(middleware): fix LLM fallback run status (#3321)
* Fix LLM fallback run status * optimize LLM fallback marker extraction in streaming path
This commit is contained in:
+40
-4
@@ -177,6 +177,24 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
def _build_circuit_breaker_message(self) -> str:
    """Return the fixed user-facing text shown while the circuit breaker is open."""
    notice = (
        "The configured LLM provider is currently unavailable due to continuous failures. "
        "Circuit breaker is engaged to protect the system. "
        "Please wait a moment before trying again."
    )
    return notice
|
||||||
|
|
||||||
|
def _build_error_fallback_message(
    self,
    content: str,
    *,
    error_type: str,
    reason: str,
    detail: str,
) -> AIMessage:
    """Wrap user-facing text in an ``AIMessage`` tagged as an LLM error fallback.

    Args:
        content: Text placed in the message body.
        error_type: Stored under ``additional_kwargs["error_type"]``.
        reason: Stored under ``additional_kwargs["error_reason"]``.
        detail: Stored under ``additional_kwargs["error_detail"]``.

    Returns:
        An ``AIMessage`` whose ``additional_kwargs`` carries
        ``deerflow_error_fallback=True`` plus the three error fields.
    """
    # The boolean marker lets consumers tell a synthesized fallback apart
    # from a genuine model reply without parsing the text.
    fallback_metadata = {
        "deerflow_error_fallback": True,
        "error_type": error_type,
        "error_reason": reason,
        "error_detail": detail,
    }
    return AIMessage(content=content, additional_kwargs=fallback_metadata)
|
||||||
|
|
||||||
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
||||||
detail = _extract_error_detail(exc)
|
detail = _extract_error_detail(exc)
|
||||||
if reason == "quota":
|
if reason == "quota":
|
||||||
@@ -187,6 +205,14 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
|
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
|
||||||
return f"LLM request failed: {detail}"
|
return f"LLM request failed: {detail}"
|
||||||
|
|
||||||
|
def _build_user_fallback_message(self, exc: BaseException, reason: str) -> AIMessage:
    """Build the tagged fallback message for a failed model call.

    Args:
        exc: The exception raised by the model call.
        reason: Classification of the failure, forwarded unchanged.

    Returns:
        The message produced by ``_build_error_fallback_message``, with the
        text from ``_build_user_message``, the exception class name as the
        error type, and the extracted error detail.
    """
    user_text = self._build_user_message(exc, reason)
    exc_name = type(exc).__name__
    exc_detail = _extract_error_detail(exc)
    return self._build_error_fallback_message(
        user_text,
        error_type=exc_name,
        reason=reason,
        detail=exc_detail,
    )
|
||||||
|
|
||||||
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
|
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
|
||||||
try:
|
try:
|
||||||
from langgraph.config import get_stream_writer
|
from langgraph.config import get_stream_writer
|
||||||
@@ -212,7 +238,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
handler: Callable[[ModelRequest], ModelResponse],
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
if self._check_circuit():
|
if self._check_circuit():
|
||||||
return AIMessage(content=self._build_circuit_breaker_message())
|
return self._build_error_fallback_message(
|
||||||
|
self._build_circuit_breaker_message(),
|
||||||
|
error_type="CircuitBreakerOpen",
|
||||||
|
reason="circuit_open",
|
||||||
|
detail="LLM circuit breaker is open",
|
||||||
|
)
|
||||||
|
|
||||||
attempt = 1
|
attempt = 1
|
||||||
while True:
|
while True:
|
||||||
@@ -249,7 +280,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
)
|
)
|
||||||
if retriable:
|
if retriable:
|
||||||
self._record_failure()
|
self._record_failure()
|
||||||
return AIMessage(content=self._build_user_message(exc, reason))
|
return self._build_user_fallback_message(exc, reason)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def awrap_model_call(
|
async def awrap_model_call(
|
||||||
@@ -258,7 +289,12 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
if self._check_circuit():
|
if self._check_circuit():
|
||||||
return AIMessage(content=self._build_circuit_breaker_message())
|
return self._build_error_fallback_message(
|
||||||
|
self._build_circuit_breaker_message(),
|
||||||
|
error_type="CircuitBreakerOpen",
|
||||||
|
reason="circuit_open",
|
||||||
|
detail="LLM circuit breaker is open",
|
||||||
|
)
|
||||||
|
|
||||||
attempt = 1
|
attempt = 1
|
||||||
while True:
|
while True:
|
||||||
@@ -295,7 +331,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
)
|
)
|
||||||
if retriable:
|
if retriable:
|
||||||
self._record_failure()
|
self._record_failure()
|
||||||
return AIMessage(content=self._build_user_message(exc, reason))
|
return self._build_user_fallback_message(exc, reason)
|
||||||
|
|
||||||
|
|
||||||
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
|
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
|
||||||
|
|||||||
@@ -86,6 +86,8 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
self._last_ai_msg: str | None = None
|
self._last_ai_msg: str | None = None
|
||||||
self._first_human_msg: str | None = None
|
self._first_human_msg: str | None = None
|
||||||
self._msg_count = 0
|
self._msg_count = 0
|
||||||
|
self._had_llm_error_fallback = False
|
||||||
|
self._llm_error_fallback_message: str | None = None
|
||||||
|
|
||||||
# Latency tracking
|
# Latency tracking
|
||||||
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
|
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
|
||||||
@@ -256,6 +258,18 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
# Token usage from message
|
# Token usage from message
|
||||||
usage = getattr(message, "usage_metadata", None)
|
usage = getattr(message, "usage_metadata", None)
|
||||||
usage_dict = dict(usage) if usage else {}
|
usage_dict = dict(usage) if usage else {}
|
||||||
|
additional_kwargs = getattr(message, "additional_kwargs", None) or {}
|
||||||
|
if isinstance(additional_kwargs, dict) and additional_kwargs.get("deerflow_error_fallback"):
|
||||||
|
self._had_llm_error_fallback = True
|
||||||
|
detail = additional_kwargs.get("error_detail")
|
||||||
|
reason = additional_kwargs.get("error_reason")
|
||||||
|
fallback_text = self._message_text(message).strip()
|
||||||
|
if isinstance(detail, str) and detail.strip():
|
||||||
|
self._llm_error_fallback_message = detail.strip()
|
||||||
|
elif isinstance(reason, str) and reason.strip():
|
||||||
|
self._llm_error_fallback_message = reason.strip()
|
||||||
|
elif fallback_text:
|
||||||
|
self._llm_error_fallback_message = fallback_text[:2000]
|
||||||
|
|
||||||
# Resolve call index
|
# Resolve call index
|
||||||
call_index = self._llm_call_index
|
call_index = self._llm_call_index
|
||||||
@@ -569,3 +583,11 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
"last_ai_message": self._last_ai_msg,
|
"last_ai_message": self._last_ai_msg,
|
||||||
"first_human_message": self._first_human_msg,
|
"first_human_message": self._first_human_msg,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
def had_llm_error_fallback(self) -> bool:
    """Whether this journal has recorded a message tagged as an LLM error fallback."""
    return self._had_llm_error_fallback
|
||||||
|
|
||||||
|
@property
|
||||||
|
def llm_error_fallback_message(self) -> str | None:
|
||||||
|
return self._llm_error_fallback_message
|
||||||
|
|||||||
@@ -150,6 +150,7 @@ async def run_agent(
|
|||||||
pre_run_checkpoint_id: str | None = None
|
pre_run_checkpoint_id: str | None = None
|
||||||
pre_run_snapshot: dict[str, Any] | None = None
|
pre_run_snapshot: dict[str, Any] | None = None
|
||||||
snapshot_capture_failed = False
|
snapshot_capture_failed = False
|
||||||
|
llm_error_fallback_message: str | None = None
|
||||||
|
|
||||||
journal = None
|
journal = None
|
||||||
|
|
||||||
@@ -312,6 +313,7 @@ async def run_agent(
|
|||||||
if record.abort_event.is_set():
|
if record.abort_event.is_set():
|
||||||
logger.info("Run %s abort requested — stopping", run_id)
|
logger.info("Run %s abort requested — stopping", run_id)
|
||||||
break
|
break
|
||||||
|
llm_error_fallback_message = llm_error_fallback_message or _extract_llm_error_fallback_message(chunk)
|
||||||
sse_event = _lg_mode_to_sse_event(single_mode)
|
sse_event = _lg_mode_to_sse_event(single_mode)
|
||||||
await bridge.publish(run_id, sse_event, serialize(chunk, mode=single_mode))
|
await bridge.publish(run_id, sse_event, serialize(chunk, mode=single_mode))
|
||||||
else:
|
else:
|
||||||
@@ -330,6 +332,7 @@ async def run_agent(
|
|||||||
if mode is None:
|
if mode is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
llm_error_fallback_message = llm_error_fallback_message or _extract_llm_error_fallback_message(chunk)
|
||||||
sse_event = _lg_mode_to_sse_event(mode)
|
sse_event = _lg_mode_to_sse_event(mode)
|
||||||
await bridge.publish(run_id, sse_event, serialize(chunk, mode=mode))
|
await bridge.publish(run_id, sse_event, serialize(chunk, mode=mode))
|
||||||
|
|
||||||
@@ -352,6 +355,12 @@ async def run_agent(
|
|||||||
logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True)
|
logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True)
|
||||||
else:
|
else:
|
||||||
await run_manager.set_status(run_id, RunStatus.interrupted)
|
await run_manager.set_status(run_id, RunStatus.interrupted)
|
||||||
|
elif llm_error_fallback_message or (journal is not None and journal.had_llm_error_fallback):
|
||||||
|
error_msg = llm_error_fallback_message
|
||||||
|
if error_msg is None and journal is not None:
|
||||||
|
error_msg = journal.llm_error_fallback_message
|
||||||
|
error_msg = error_msg or "LLM provider failed after retries"
|
||||||
|
await run_manager.set_status(run_id, RunStatus.error, error=error_msg)
|
||||||
else:
|
else:
|
||||||
await run_manager.set_status(run_id, RunStatus.success)
|
await run_manager.set_status(run_id, RunStatus.success)
|
||||||
|
|
||||||
@@ -554,6 +563,85 @@ def _lg_mode_to_sse_event(mode: str) -> str:
|
|||||||
return mode
|
return mode
|
||||||
|
|
||||||
|
|
||||||
|
def _error_fallback_message_from_metadata(metadata: dict[str, Any], content: Any) -> str:
    """Pick the best available error description for a fallback message.

    Preference order: ``error_detail``, then ``error_reason``, then the
    message ``content`` (truncated to 2000 characters), then a generic
    default. A candidate counts only if it is a string that is non-empty
    after stripping whitespace.

    Args:
        metadata: The message's ``additional_kwargs`` mapping.
        content: The message content; ignored unless it is a string.

    Returns:
        The chosen description, stripped of surrounding whitespace.
    """
    for key in ("error_detail", "error_reason"):
        candidate = metadata.get(key)
        if isinstance(candidate, str):
            stripped = candidate.strip()
            if stripped:
                return stripped

    if isinstance(content, str):
        text = content.strip()
        if text:
            # Only the content fallback is length-capped.
            return text[:2000]

    return "LLM provider failed after retries"
|
||||||
|
|
||||||
|
|
||||||
|
def _try_extract_from_message(obj: Any) -> str | None:
|
||||||
|
"""Try to extract fallback marker from a single message object or dict."""
|
||||||
|
additional_kwargs = getattr(obj, "additional_kwargs", None)
|
||||||
|
if isinstance(additional_kwargs, dict) and additional_kwargs.get("deerflow_error_fallback"):
|
||||||
|
return _error_fallback_message_from_metadata(additional_kwargs, getattr(obj, "content", None))
|
||||||
|
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
nested_kwargs = obj.get("additional_kwargs")
|
||||||
|
if isinstance(nested_kwargs, dict) and nested_kwargs.get("deerflow_error_fallback"):
|
||||||
|
return _error_fallback_message_from_metadata(nested_kwargs, obj.get("content"))
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_llm_error_fallback_message(value: Any) -> str | None:
|
||||||
|
"""Find LLM fallback markers in streamed LangGraph chunks.
|
||||||
|
|
||||||
|
Error fallback messages returned by model-call middleware are not guaranteed
|
||||||
|
to pass through LLM end callbacks, but they do appear in graph state chunks.
|
||||||
|
"""
|
||||||
|
# Fast path: large state chunks produced by stream_mode="values" have a
|
||||||
|
# top-level "messages" list. Scanning only that list avoids expensive deep
|
||||||
|
# recursion into large state dicts.
|
||||||
|
if isinstance(value, dict):
|
||||||
|
messages = value.get("messages")
|
||||||
|
if isinstance(messages, (list, tuple)):
|
||||||
|
for msg in messages:
|
||||||
|
result = _try_extract_from_message(msg)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
# Fallback marker is attached to an AI message in the messages
|
||||||
|
# channel; it will never appear elsewhere in a values chunk.
|
||||||
|
return None
|
||||||
|
# No top-level "messages" — this is likely an "updates" chunk (small
|
||||||
|
# dict keyed by node name). Fall through to deep walk, which is cheap
|
||||||
|
# for these payloads.
|
||||||
|
|
||||||
|
# Deep walk for updates / messages / tuple / list modes. Payloads are
|
||||||
|
# small, so full recursion is acceptable here.
|
||||||
|
seen: set[int] = set()
|
||||||
|
|
||||||
|
def walk(obj: Any) -> str | None:
|
||||||
|
oid = id(obj)
|
||||||
|
if oid in seen:
|
||||||
|
return None
|
||||||
|
seen.add(oid)
|
||||||
|
|
||||||
|
result = _try_extract_from_message(obj)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
for item in obj.values():
|
||||||
|
result = walk(item)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(obj, (list, tuple, set)):
|
||||||
|
for item in obj:
|
||||||
|
result = walk(item)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
return walk(value)
|
||||||
|
|
||||||
|
|
||||||
def _extract_human_message(graph_input: dict) -> HumanMessage | None:
|
def _extract_human_message(graph_input: dict) -> HumanMessage | None:
|
||||||
"""Extract or construct a HumanMessage from graph_input for event recording.
|
"""Extract or construct a HumanMessage from graph_input for event recording.
|
||||||
|
|
||||||
|
|||||||
@@ -94,6 +94,31 @@ def test_async_model_call_returns_user_message_for_quota_errors() -> None:
|
|||||||
|
|
||||||
assert isinstance(result, AIMessage)
|
assert isinstance(result, AIMessage)
|
||||||
assert "out of quota" in str(result.content)
|
assert "out of quota" in str(result.content)
|
||||||
|
assert result.additional_kwargs["deerflow_error_fallback"] is True
|
||||||
|
assert result.additional_kwargs["error_reason"] == "quota"
|
||||||
|
assert result.additional_kwargs["error_type"] == "FakeError"
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_model_call_marks_transient_retry_exhaustion_as_error_fallback(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A handler that always raises a 503 ends in a tagged transient fallback message."""
    middleware = _build_middleware(retry_max_attempts=2, retry_base_delay_ms=25, retry_cap_delay_ms=25)

    async def instant_sleep(_delay: float) -> None:
        # Skip real back-off waits so the test stays fast.
        return None

    async def always_unavailable(_request) -> AIMessage:
        raise FakeError("Connection error.", status_code=503)

    monkeypatch.setattr("asyncio.sleep", instant_sleep)

    fallback = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), always_unavailable))

    assert isinstance(fallback, AIMessage)
    assert "temporarily unavailable" in str(fallback.content)
    metadata = fallback.additional_kwargs
    assert metadata["deerflow_error_fallback"] is True
    assert metadata["error_reason"] == "transient"
    assert metadata["error_detail"] == "Connection error."
|
||||||
|
|
||||||
|
|
||||||
def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
|||||||
@@ -3,12 +3,22 @@ from types import SimpleNamespace
|
|||||||
from unittest.mock import AsyncMock, call
|
from unittest.mock import AsyncMock, call
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
from langgraph.checkpoint.base import empty_checkpoint
|
from langgraph.checkpoint.base import empty_checkpoint
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
from deerflow.runtime.runs.manager import RunManager
|
from deerflow.runtime.runs.manager import RunManager
|
||||||
from deerflow.runtime.runs.schemas import RunStatus
|
from deerflow.runtime.runs.schemas import RunStatus
|
||||||
from deerflow.runtime.runs.worker import RunContext, _agent_factory_supports_app_config, _build_runtime_context, _install_runtime_context, _rollback_to_pre_run_checkpoint, run_agent
|
from deerflow.runtime.runs.worker import (
|
||||||
|
RunContext,
|
||||||
|
_agent_factory_supports_app_config,
|
||||||
|
_build_runtime_context,
|
||||||
|
_extract_llm_error_fallback_message,
|
||||||
|
_install_runtime_context,
|
||||||
|
_rollback_to_pre_run_checkpoint,
|
||||||
|
_try_extract_from_message,
|
||||||
|
run_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FakeCheckpointer:
|
class FakeCheckpointer:
|
||||||
@@ -95,6 +105,52 @@ async def test_run_agent_threads_explicit_app_config_into_config_only_factory():
|
|||||||
bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60)
|
bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_run_agent_marks_llm_error_fallback_as_error_status():
    """A streamed fallback-marked message makes the run finish with error status."""
    manager = RunManager()
    run_record = await manager.create("thread-1")
    stream_bridge = SimpleNamespace(
        publish=AsyncMock(),
        publish_end=AsyncMock(),
        cleanup=AsyncMock(),
    )

    fallback_message = AIMessage(
        content="The configured LLM provider is temporarily unavailable after multiple retries.",
        additional_kwargs={
            "deerflow_error_fallback": True,
            "error_type": "APIConnectionError",
            "error_reason": "transient",
            "error_detail": "Connection error.",
        },
    )

    class DummyAgent:
        async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False):
            # Emit a single state-shaped chunk containing only the fallback.
            yield {"messages": [fallback_message]}

    def factory(*, config):
        return DummyAgent()

    await run_agent(
        stream_bridge,
        manager,
        run_record,
        ctx=RunContext(checkpointer=None),
        agent_factory=factory,
        graph_input={},
        config={},
    )

    stored = await manager.get(run_record.run_id)
    assert stored is not None
    assert stored.status == RunStatus.error
    assert stored.error == "Connection error."
    stream_bridge.publish_end.assert_awaited_once_with(run_record.run_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_run_agent_defaults_root_run_name_from_assistant_id():
|
async def test_run_agent_defaults_root_run_name_from_assistant_id():
|
||||||
run_manager = RunManager()
|
run_manager = RunManager()
|
||||||
@@ -486,3 +542,133 @@ def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_f
|
|||||||
monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", lambda _obj: (_ for _ in ()).throw(ValueError("boom")))
|
monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", lambda _obj: (_ for _ in ()).throw(ValueError("boom")))
|
||||||
|
|
||||||
assert _agent_factory_supports_app_config(BrokenCallable()) is False
|
assert _agent_factory_supports_app_config(BrokenCallable()) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _extract_llm_error_fallback_message coverage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_try_extract_from_message_finds_fallback_on_message_object():
    """``error_detail`` wins over ``error_reason`` on a marked message object."""
    marked_kwargs = {
        "deerflow_error_fallback": True,
        "error_detail": "Connection error.",
        "error_reason": "transient",
    }
    marked = AIMessage(content="fallback", additional_kwargs=marked_kwargs)

    extracted = _try_extract_from_message(marked)

    assert extracted == "Connection error."
|
||||||
|
|
||||||
|
|
||||||
|
def test_try_extract_from_message_finds_fallback_on_dict():
    """A dict-shaped message with the marker yields its ``error_detail``."""
    marked_kwargs = {
        "deerflow_error_fallback": True,
        "error_detail": "Quota exceeded.",
    }
    dict_message = {"content": "fallback", "additional_kwargs": marked_kwargs}

    extracted = _try_extract_from_message(dict_message)

    assert extracted == "Quota exceeded."
|
||||||
|
|
||||||
|
|
||||||
|
def test_try_extract_from_message_returns_none_for_normal_message():
    """An ordinary AI message without the marker is not treated as a fallback."""
    ordinary = AIMessage(content="hello")

    extracted = _try_extract_from_message(ordinary)

    assert extracted is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_llm_error_fallback_message_large_state_chunk_no_fallback():
    """A bulky state dict with no fallback marker must return None without raising."""
    unmarked_messages = [
        AIMessage(content="Hello!"),
        {"role": "user", "content": "Hi there"},
    ]
    bulky_rows = [{"id": i, "payload": "y" * 1000} for i in range(500)]
    large_state = {
        "messages": unmarked_messages,
        "foo": "x" * 10_000,
        "bar": {"nested": {"deep": {"data": list(range(1000))}}},
        "baz": bulky_rows,
    }

    extracted = _extract_llm_error_fallback_message(large_state)

    assert extracted is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_llm_error_fallback_message_finds_fallback_in_messages_list():
    """A marked message later in the top-level ``messages`` list is found."""
    marked = AIMessage(
        content="Unavailable.",
        additional_kwargs={
            "deerflow_error_fallback": True,
            "error_detail": "Connection error.",
        },
    )
    state = {
        "messages": [AIMessage(content="Hello!"), marked],
        "other_state": "large_value" * 1000,
    }

    extracted = _extract_llm_error_fallback_message(state)

    assert extracted == "Connection error."
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_llm_error_fallback_message_finds_fallback_in_raw_message():
    """A bare marked message with only ``error_reason`` yields that reason."""
    marked_kwargs = {
        "deerflow_error_fallback": True,
        "error_reason": "quota",
    }
    raw_message = AIMessage(content="Unavailable.", additional_kwargs=marked_kwargs)

    extracted = _extract_llm_error_fallback_message(raw_message)

    assert extracted == "quota"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_llm_error_fallback_message_finds_fallback_in_tuple():
    """A marked message nested inside a tuple chunk is found."""
    marked = AIMessage(
        content="Unavailable.",
        additional_kwargs={
            "deerflow_error_fallback": True,
            "error_detail": "Circuit open.",
        },
    )
    tuple_chunk = ("messages", marked)

    extracted = _extract_llm_error_fallback_message(tuple_chunk)

    assert extracted == "Circuit open."
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_llm_error_fallback_message_returns_none_for_empty_values():
    """Empty containers, None and plain strings never produce a fallback message."""
    for empty_value in ({}, [], None, "string"):
        assert _extract_llm_error_fallback_message(empty_value) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_llm_error_fallback_message_finds_fallback_in_updates_mode():
    """stream_mode='updates' yields dicts keyed by node name (e.g. {'call_model': {...}}).

    The fallback marker sits inside the node's state update rather than at the
    top level, so it is only reachable through the deep walk.
    """
    marked = AIMessage(
        content="Unavailable.",
        additional_kwargs={
            "deerflow_error_fallback": True,
            "error_detail": "Connection error.",
        },
    )
    update_chunk = {"call_model": {"messages": [marked]}}

    extracted = _extract_llm_error_fallback_message(update_chunk)

    assert extracted == "Connection error."
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_llm_error_fallback_message_updates_mode_no_fallback():
    """An updates chunk that carries only an interrupt entry returns None."""
    interrupt_entry = {
        "value": "ask_human",
        "resumable": True,
        "ns": ["agent"],
        "when": "during",
    }
    update_chunk = {"__interrupt__": [interrupt_entry]}

    extracted = _extract_llm_error_fallback_message(update_chunk)

    assert extracted is None
|
||||||
|
|||||||
Reference in New Issue
Block a user