mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
feat(llm): introduce lightweight circuit breaker to prevent rate-limit bans and resource exhaustion (#2095)
This commit is contained in:
+102
-2
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from email.utils import parsedate_to_datetime
|
||||
@@ -19,6 +20,8 @@ from langchain.agents.middleware.types import (
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
|
||||
@@ -67,6 +70,80 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
retry_base_delay_ms: int = 1000
|
||||
retry_cap_delay_ms: int = 8000
|
||||
|
||||
circuit_failure_threshold: int = 5
|
||||
circuit_recovery_timeout_sec: int = 60
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
    """Set up the middleware together with its circuit-breaker bookkeeping.

    All keyword arguments are forwarded unchanged to the parent initializer.
    The breaker thresholds are then read from the application config; when
    loading that config raises ``FileNotFoundError`` or ``RuntimeError`` the
    class-level defaults stay in effect.
    """
    super().__init__(**kwargs)

    # Mutable breaker state; every access after construction happens while
    # holding ``_circuit_lock``.
    self._circuit_lock = threading.Lock()
    self._circuit_state = "closed"
    self._circuit_failure_count = 0
    self._circuit_open_until = 0.0
    self._circuit_probe_in_flight = False

    # Prefer values from the app config, keeping the class defaults when the
    # config cannot be loaded (for example in test environments).
    try:
        breaker_settings = get_app_config().circuit_breaker
        self.circuit_failure_threshold = breaker_settings.failure_threshold
        self.circuit_recovery_timeout_sec = breaker_settings.recovery_timeout_sec
    except (FileNotFoundError, RuntimeError):
        pass
|
||||
|
||||
def _check_circuit(self) -> bool:
    """Decide whether the current call must be rejected without reaching the LLM.

    Returns:
        ``True`` when the caller should fail fast (the circuit is open, or a
        half-open probe is still outstanding); ``False`` when the call may
        proceed.

    State handling, all performed under ``_circuit_lock``:

    * ``"closed"``: the call always proceeds.
    * ``"open"``: calls are rejected until ``_circuit_open_until`` has
      passed; the first call after that moves the breaker to ``"half_open"``.
    * ``"half_open"``: exactly one caller is admitted as the probe and marks
      ``_circuit_probe_in_flight``; other callers are rejected while the
      probe is outstanding.

    A probe whose outcome is never reported (neither ``_record_success`` nor
    ``_record_failure`` runs and the flag is not cleared by the caller) would
    otherwise leave the breaker rejecting every call indefinitely. To avoid
    that, admitting a probe also grants it a lease of
    ``circuit_recovery_timeout_sec`` seconds, stored in
    ``_circuit_open_until``; once the lease has expired the next caller is
    admitted as a fresh probe.
    """
    with self._circuit_lock:
        now = time.time()

        if self._circuit_state == "open":
            if now < self._circuit_open_until:
                return True
            # Recovery window elapsed: fall through and admit one probe.
            self._circuit_state = "half_open"
            self._circuit_probe_in_flight = False

        if self._circuit_state == "half_open":
            probe_lease_active = now < self._circuit_open_until
            if self._circuit_probe_in_flight and probe_lease_active:
                return True
            # Admit this caller as the probe and (re)start its lease so an
            # abandoned probe cannot wedge the breaker in half-open.
            self._circuit_probe_in_flight = True
            self._circuit_open_until = now + self.circuit_recovery_timeout_sec
            return False

        return False
|
||||
|
||||
def _record_success(self) -> None:
    """Reset the breaker to its pristine closed state after a successful call.

    An info line is logged only when there was something to recover from,
    i.e. the breaker was not closed or failures had been counted.
    """
    with self._circuit_lock:
        had_failures = self._circuit_failure_count > 0
        if self._circuit_state != "closed" or had_failures:
            logger.info("Circuit breaker reset (Closed). LLM service recovered.")

        # Unconditional reset keeps all four state fields consistent.
        self._circuit_state = "closed"
        self._circuit_probe_in_flight = False
        self._circuit_open_until = 0.0
        self._circuit_failure_count = 0
|
||||
|
||||
def _record_failure(self) -> None:
    """Register one failed model call and open the circuit when warranted.

    * In ``"half_open"`` the failure is a failed probe: the circuit re-opens
      for another ``circuit_recovery_timeout_sec`` seconds and the failure
      counter is left untouched.
    * Otherwise the failure counter is incremented; once it reaches
      ``circuit_failure_threshold`` the open-until deadline is (re)set and,
      unless the circuit is already open, it transitions to ``"open"`` and
      the trip is logged.
    """
    with self._circuit_lock:
        if self._circuit_state != "half_open":
            self._circuit_failure_count += 1
            if self._circuit_failure_count < self.circuit_failure_threshold:
                return

            # Threshold reached: push the deadline out even if already open.
            self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
            if self._circuit_state == "open":
                return

            self._circuit_state = "open"
            self._circuit_probe_in_flight = False
            logger.error(
                "Circuit breaker tripped (Open). Threshold reached (%d). Will probe after %ds.",
                self.circuit_failure_threshold,
                self.circuit_recovery_timeout_sec,
            )
            return

        # Probe failed while half-open: go straight back to open.
        self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
        self._circuit_state = "open"
        self._circuit_probe_in_flight = False
        logger.error(
            "Circuit breaker probe failed (Open). Will probe again after %ds.",
            self.circuit_recovery_timeout_sec,
        )
|
||||
|
||||
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
|
||||
detail = _extract_error_detail(exc)
|
||||
lowered = detail.lower()
|
||||
@@ -104,6 +181,9 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
|
||||
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
|
||||
|
||||
def _build_circuit_breaker_message(self) -> str:
    """Return the fixed user-facing text shown while the circuit breaker rejects calls."""
    sentences = (
        "The configured LLM provider is currently unavailable due to continuous failures.",
        "Circuit breaker is engaged to protect the system.",
        "Please wait a moment before trying again.",
    )
    return " ".join(sentences)
|
||||
|
||||
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
||||
detail = _extract_error_detail(exc)
|
||||
if reason == "quota":
|
||||
@@ -138,12 +218,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
if self._check_circuit():
|
||||
return AIMessage(content=self._build_circuit_breaker_message())
|
||||
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return handler(request)
|
||||
response = handler(request)
|
||||
self._record_success()
|
||||
return response
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
with self._circuit_lock:
|
||||
if self._circuit_state == "half_open":
|
||||
self._circuit_probe_in_flight = False
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
@@ -166,6 +254,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
if retriable:
|
||||
self._record_failure()
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
@override
|
||||
@@ -174,12 +264,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
if self._check_circuit():
|
||||
return AIMessage(content=self._build_circuit_breaker_message())
|
||||
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return await handler(request)
|
||||
response = await handler(request)
|
||||
self._record_success()
|
||||
return response
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
with self._circuit_lock:
|
||||
if self._circuit_state == "half_open":
|
||||
self._circuit_probe_in_flight = False
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
@@ -202,6 +300,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
if retriable:
|
||||
self._record_failure()
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
|
||||
|
||||
@@ -30,6 +30,13 @@ load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitBreakerConfig(BaseModel):
    """Configuration for the LLM Circuit Breaker.

    Both values are bounds-checked at load time so a typo in the config file
    is reported as a validation error instead of silently producing a breaker
    that trips on the very first failure (threshold <= 0) or whose open
    window has already expired when it is set (negative timeout).
    """

    # Must be at least 1: the breaker compares the running failure count
    # against this value with ``>=``.
    failure_threshold: int = Field(
        default=5,
        ge=1,
        description="Number of consecutive failures before tripping the circuit",
    )
    # Zero is allowed and means "probe again on the next call".
    recovery_timeout_sec: int = Field(
        default=60,
        ge=0,
        description="Time in seconds before attempting to recover the circuit",
    )
|
||||
|
||||
|
||||
def _default_config_candidates() -> tuple[Path, ...]:
|
||||
"""Return deterministic config.yaml locations without relying on cwd."""
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
@@ -55,6 +62,7 @@ class AppConfig(BaseModel):
|
||||
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
|
||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||
model_config = ConfigDict(extra="allow", frozen=False)
|
||||
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
|
||||
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
|
||||
@@ -129,6 +137,10 @@ class AppConfig(BaseModel):
|
||||
if "guardrails" in config_data:
|
||||
load_guardrails_config_from_dict(config_data["guardrails"])
|
||||
|
||||
# Load circuit_breaker config if present
|
||||
if "circuit_breaker" in config_data:
|
||||
config_data["circuit_breaker"] = config_data["circuit_breaker"]
|
||||
|
||||
# Load checkpointer config if present
|
||||
if "checkpointer" in config_data:
|
||||
load_checkpointer_config_from_dict(config_data["checkpointer"])
|
||||
|
||||
Reference in New Issue
Block a user