feat(router.py): enable default fallbacks

allow user to define a generic list of fallbacks, in case a new deployment is bad

Closes https://github.com/BerriAI/litellm/issues/3623
This commit is contained in:
Krrish Dholakia 2024-05-13 17:49:56 -07:00
parent bf8d3be791
commit 6f20389bd5
3 changed files with 132 additions and 11 deletions

View file

@ -9,7 +9,7 @@
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
from typing_extensions import overload
import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json
@ -47,6 +47,7 @@ from litellm.types.router import (
updateLiteLLMParams,
RetryPolicy,
AlertingConfig,
DeploymentTypedDict,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
@ -62,7 +63,7 @@ class Router:
def __init__(
self,
model_list: Optional[list] = None,
model_list: Optional[List[DeploymentTypedDict]] = None,
## CACHING ##
redis_url: Optional[str] = None,
redis_host: Optional[str] = None,
@ -83,6 +84,9 @@ class Router:
default_max_parallel_requests: Optional[int] = None,
set_verbose: bool = False,
debug_level: Literal["DEBUG", "INFO"] = "INFO",
default_fallbacks: Optional[
List[str]
] = None, # generic fallbacks, works across all deployments
fallbacks: List = [],
context_window_fallbacks: List = [],
model_group_alias: Optional[dict] = {},
@ -259,6 +263,11 @@ class Router:
self.retry_after = retry_after
self.routing_strategy = routing_strategy
self.fallbacks = fallbacks or litellm.fallbacks
if default_fallbacks is not None:
if self.fallbacks is not None:
self.fallbacks.append({"*": default_fallbacks})
else:
self.fallbacks = [{"*": default_fallbacks}]
self.context_window_fallbacks = (
context_window_fallbacks or litellm.context_window_fallbacks
)
@ -1471,13 +1480,21 @@ class Router:
pass
elif fallbacks is not None:
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
for item in fallbacks:
key_list = list(item.keys())
if len(key_list) == 0:
continue
if key_list[0] == model_group:
generic_fallback_idx: Optional[int] = None
## check for specific model group-specific fallbacks
for idx, item in enumerate(fallbacks):
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
elif list(item.keys())[0] == "*":
generic_fallback_idx = idx
## if none, check for generic fallback
if (
fallback_model_group is None
and generic_fallback_idx is not None
):
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
if fallback_model_group is None:
verbose_router_logger.info(
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
@ -1537,7 +1554,7 @@ class Router:
"""
_healthy_deployments = await self._async_get_healthy_deployments(
model=kwargs.get("model"),
model=kwargs.get("model") or "",
)
# raises an exception if this error should not be retries
@ -1644,12 +1661,18 @@ class Router:
Try calling the function_with_retries
If it fails after num_retries, fall back to another model group
"""
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
model_group = kwargs.get("model")
fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks
)
try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
)
response = self.function_with_retries(*args, **kwargs)
return response
except Exception as e:
@ -1658,7 +1681,7 @@ class Router:
try:
if (
hasattr(e, "status_code")
and e.status_code == 400
and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError)
): # don't retry a malformed request
raise e
@ -1700,10 +1723,20 @@ class Router:
elif fallbacks is not None:
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
fallback_model_group = None
for item in fallbacks:
generic_fallback_idx: Optional[int] = None
## check for specific model group-specific fallbacks
for idx, item in enumerate(fallbacks):
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
elif list(item.keys())[0] == "*":
generic_fallback_idx = idx
## if none, check for generic fallback
if (
fallback_model_group is None
and generic_fallback_idx is not None
):
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
if fallback_model_group is None:
raise original_exception