mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
fix(middleware): fix LLM fallback run status (#3321)
* Fix LLM fallback run status * optimize LLM fallback marker extraction in streaming path
This commit is contained in:
@@ -3,12 +3,22 @@ from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, call
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from deerflow.runtime.runs.manager import RunManager
|
||||
from deerflow.runtime.runs.schemas import RunStatus
|
||||
from deerflow.runtime.runs.worker import RunContext, _agent_factory_supports_app_config, _build_runtime_context, _install_runtime_context, _rollback_to_pre_run_checkpoint, run_agent
|
||||
from deerflow.runtime.runs.worker import (
|
||||
RunContext,
|
||||
_agent_factory_supports_app_config,
|
||||
_build_runtime_context,
|
||||
_extract_llm_error_fallback_message,
|
||||
_install_runtime_context,
|
||||
_rollback_to_pre_run_checkpoint,
|
||||
_try_extract_from_message,
|
||||
run_agent,
|
||||
)
|
||||
|
||||
|
||||
class FakeCheckpointer:
|
||||
@@ -95,6 +105,52 @@ async def test_run_agent_threads_explicit_app_config_into_config_only_factory():
|
||||
bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60)
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_run_agent_marks_llm_error_fallback_as_error_status():
    """Check the run record after the agent streams an LLM error-fallback message.

    The dummy agent emits a single chunk whose only message carries the
    ``deerflow_error_fallback`` marker. Once ``run_agent`` returns, the stored
    run is expected to have ``RunStatus.error``, to expose the marker's
    ``error_detail`` as its error text, and the bridge's ``publish_end`` is
    expected to have been awaited exactly once with the run id.
    """
    run_manager = RunManager()
    record = await run_manager.create("thread-1")
    bridge = SimpleNamespace(publish=AsyncMock(), publish_end=AsyncMock(), cleanup=AsyncMock())

    class DummyAgent:
        """Stand-in agent whose stream yields one fallback-marked chunk."""

        async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False):
            fallback_marker = {
                "deerflow_error_fallback": True,
                "error_type": "APIConnectionError",
                "error_reason": "transient",
                "error_detail": "Connection error.",
            }
            fallback_reply = AIMessage(
                content="The configured LLM provider is temporarily unavailable after multiple retries.",
                additional_kwargs=fallback_marker,
            )
            yield {"messages": [fallback_reply]}

    def factory(*, config):
        return DummyAgent()

    await run_agent(
        bridge,
        run_manager,
        record,
        ctx=RunContext(checkpointer=None),
        agent_factory=factory,
        graph_input={},
        config={},
    )

    stored = await run_manager.get(record.run_id)
    assert stored is not None
    assert stored.status == RunStatus.error
    assert stored.error == "Connection error."
    bridge.publish_end.assert_awaited_once_with(record.run_id)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_agent_defaults_root_run_name_from_assistant_id():
|
||||
run_manager = RunManager()
|
||||
@@ -486,3 +542,133 @@ def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_f
|
||||
monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", lambda _obj: (_ for _ in ()).throw(ValueError("boom")))
|
||||
|
||||
assert _agent_factory_supports_app_config(BrokenCallable()) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_llm_error_fallback_message coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_try_extract_from_message_finds_fallback_on_message_object():
    """A fallback-marked ``AIMessage`` is expected to yield its ``error_detail``.

    Both ``error_detail`` and ``error_reason`` are set on the message; the
    assertion expects the detail text to be the one returned.
    """
    marker_kwargs = {
        "deerflow_error_fallback": True,
        "error_detail": "Connection error.",
        "error_reason": "transient",
    }
    fallback = AIMessage(content="fallback", additional_kwargs=marker_kwargs)

    extracted = _try_extract_from_message(fallback)

    assert extracted == "Connection error."
|
||||
|
||||
|
||||
def test_try_extract_from_message_finds_fallback_on_dict():
    """A plain-dict message carrying the fallback marker is expected to yield its ``error_detail``."""
    marker_kwargs = {
        "deerflow_error_fallback": True,
        "error_detail": "Quota exceeded.",
    }
    payload = {"content": "fallback", "additional_kwargs": marker_kwargs}

    extracted = _try_extract_from_message(payload)

    assert extracted == "Quota exceeded."
|
||||
|
||||
|
||||
def test_try_extract_from_message_returns_none_for_normal_message():
    """An ``AIMessage`` without any fallback marker is expected to produce ``None``."""
    plain_reply = AIMessage(content="hello")

    extracted = _try_extract_from_message(plain_reply)

    assert extracted is None
|
||||
|
||||
|
||||
def test_extract_llm_error_fallback_message_large_state_chunk_no_fallback():
    """A bulky state dict with no fallback marker must not raise and is expected to give ``None``."""
    ordinary_messages = [
        AIMessage(content="Hello!"),
        {"role": "user", "content": "Hi there"},
    ]
    # Sizeable unrelated state entries sitting next to the messages list.
    bulky_extras = {
        "foo": "x" * 10_000,
        "bar": {"nested": {"deep": {"data": list(range(1000))}}},
        "baz": [{"id": idx, "payload": "y" * 1000} for idx in range(500)],
    }
    large_state = {"messages": ordinary_messages, **bulky_extras}

    extracted = _extract_llm_error_fallback_message(large_state)

    assert extracted is None
|
||||
|
||||
|
||||
def test_extract_llm_error_fallback_message_finds_fallback_in_messages_list():
    """A fallback message placed after a normal one in ``messages`` is expected to be found."""
    normal_reply = AIMessage(content="Hello!")
    fallback_reply = AIMessage(
        content="Unavailable.",
        additional_kwargs={
            "deerflow_error_fallback": True,
            "error_detail": "Connection error.",
        },
    )
    state = {
        "messages": [normal_reply, fallback_reply],
        "other_state": "large_value" * 1000,
    }

    extracted = _extract_llm_error_fallback_message(state)

    assert extracted == "Connection error."
|
||||
|
||||
|
||||
def test_extract_llm_error_fallback_message_finds_fallback_in_raw_message():
    """A bare fallback ``AIMessage`` with only ``error_reason`` set is expected to yield that reason."""
    marker_kwargs = {
        "deerflow_error_fallback": True,
        "error_reason": "quota",
    }
    raw_message = AIMessage(content="Unavailable.", additional_kwargs=marker_kwargs)

    extracted = _extract_llm_error_fallback_message(raw_message)

    assert extracted == "quota"
|
||||
|
||||
|
||||
def test_extract_llm_error_fallback_message_finds_fallback_in_tuple():
    """A ``("messages", <fallback message>)`` tuple is expected to yield the message's ``error_detail``."""
    fallback_reply = AIMessage(
        content="Unavailable.",
        additional_kwargs={
            "deerflow_error_fallback": True,
            "error_detail": "Circuit open.",
        },
    )
    stream_item = ("messages", fallback_reply)

    extracted = _extract_llm_error_fallback_message(stream_item)

    assert extracted == "Circuit open."
|
||||
|
||||
|
||||
def test_extract_llm_error_fallback_message_returns_none_for_empty_values():
    """Empty containers, ``None`` and a plain string are each expected to produce ``None``."""
    marker_free_inputs = ({}, [], None, "string")
    for candidate in marker_free_inputs:
        assert _extract_llm_error_fallback_message(candidate) is None
|
||||
|
||||
|
||||
def test_extract_llm_error_fallback_message_finds_fallback_in_updates_mode():
    """A fallback nested under a node-name key is expected to be found.

    The chunk mimics ``stream_mode='updates'`` output, which is keyed by node
    name (here ``'call_model'``), so the marked message sits one level below
    the top of the dict, inside that node's state update.
    """
    fallback_reply = AIMessage(
        content="Unavailable.",
        additional_kwargs={
            "deerflow_error_fallback": True,
            "error_detail": "Connection error.",
        },
    )
    node_update = {"messages": [fallback_reply]}
    update_chunk = {"call_model": node_update}

    extracted = _extract_llm_error_fallback_message(update_chunk)

    assert extracted == "Connection error."
|
||||
|
||||
|
||||
def test_extract_llm_error_fallback_message_updates_mode_no_fallback():
    """An updates-style chunk holding only an interrupt entry is expected to produce ``None``."""
    interrupt_entry = {
        "value": "ask_human",
        "resumable": True,
        "ns": ["agent"],
        "when": "during",
    }
    update_chunk = {"__interrupt__": [interrupt_entry]}

    extracted = _extract_llm_error_fallback_message(update_chunk)

    assert extracted is None
|
||||
|
||||
Reference in New Issue
Block a user