mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat(router.py): enable default fallbacks
Allow the user to define a generic list of fallbacks, in case a new deployment is bad. Closes https://github.com/BerriAI/litellm/issues/3623
This commit is contained in:
parent
bf8d3be791
commit
6f20389bd5
3 changed files with 132 additions and 11 deletions
|
@ -9,7 +9,7 @@
|
|||
|
||||
import copy, httpx
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
|
||||
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
|
||||
from typing_extensions import overload
|
||||
import random, threading, time, traceback, uuid
|
||||
import litellm, openai, hashlib, json
|
||||
|
@ -47,6 +47,7 @@ from litellm.types.router import (
|
|||
updateLiteLLMParams,
|
||||
RetryPolicy,
|
||||
AlertingConfig,
|
||||
DeploymentTypedDict,
|
||||
)
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
|
@ -62,7 +63,7 @@ class Router:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model_list: Optional[list] = None,
|
||||
model_list: Optional[List[DeploymentTypedDict]] = None,
|
||||
## CACHING ##
|
||||
redis_url: Optional[str] = None,
|
||||
redis_host: Optional[str] = None,
|
||||
|
@ -83,6 +84,9 @@ class Router:
|
|||
default_max_parallel_requests: Optional[int] = None,
|
||||
set_verbose: bool = False,
|
||||
debug_level: Literal["DEBUG", "INFO"] = "INFO",
|
||||
default_fallbacks: Optional[
|
||||
List[str]
|
||||
] = None, # generic fallbacks, works across all deployments
|
||||
fallbacks: List = [],
|
||||
context_window_fallbacks: List = [],
|
||||
model_group_alias: Optional[dict] = {},
|
||||
|
@ -259,6 +263,11 @@ class Router:
|
|||
self.retry_after = retry_after
|
||||
self.routing_strategy = routing_strategy
|
||||
self.fallbacks = fallbacks or litellm.fallbacks
|
||||
if default_fallbacks is not None:
|
||||
if self.fallbacks is not None:
|
||||
self.fallbacks.append({"*": default_fallbacks})
|
||||
else:
|
||||
self.fallbacks = [{"*": default_fallbacks}]
|
||||
self.context_window_fallbacks = (
|
||||
context_window_fallbacks or litellm.context_window_fallbacks
|
||||
)
|
||||
|
@ -1471,13 +1480,21 @@ class Router:
|
|||
pass
|
||||
elif fallbacks is not None:
|
||||
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
|
||||
for item in fallbacks:
|
||||
key_list = list(item.keys())
|
||||
if len(key_list) == 0:
|
||||
continue
|
||||
if key_list[0] == model_group:
|
||||
generic_fallback_idx: Optional[int] = None
|
||||
## check for specific model group-specific fallbacks
|
||||
for idx, item in enumerate(fallbacks):
|
||||
if list(item.keys())[0] == model_group:
|
||||
fallback_model_group = item[model_group]
|
||||
break
|
||||
elif list(item.keys())[0] == "*":
|
||||
generic_fallback_idx = idx
|
||||
## if none, check for generic fallback
|
||||
if (
|
||||
fallback_model_group is None
|
||||
and generic_fallback_idx is not None
|
||||
):
|
||||
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
|
||||
|
||||
if fallback_model_group is None:
|
||||
verbose_router_logger.info(
|
||||
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
|
||||
|
@ -1537,7 +1554,7 @@ class Router:
|
|||
|
||||
"""
|
||||
_healthy_deployments = await self._async_get_healthy_deployments(
|
||||
model=kwargs.get("model"),
|
||||
model=kwargs.get("model") or "",
|
||||
)
|
||||
|
||||
# raises an exception if this error should not be retried
|
||||
|
@ -1644,12 +1661,18 @@ class Router:
|
|||
Try calling the function_with_retries
|
||||
If it fails after num_retries, fall back to another model group
|
||||
"""
|
||||
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
|
||||
model_group = kwargs.get("model")
|
||||
fallbacks = kwargs.get("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.get(
|
||||
"context_window_fallbacks", self.context_window_fallbacks
|
||||
)
|
||||
try:
|
||||
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
|
||||
raise Exception(
|
||||
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
|
||||
)
|
||||
|
||||
response = self.function_with_retries(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
|
@ -1658,7 +1681,7 @@ class Router:
|
|||
try:
|
||||
if (
|
||||
hasattr(e, "status_code")
|
||||
and e.status_code == 400
|
||||
and e.status_code == 400 # type: ignore
|
||||
and not isinstance(e, litellm.ContextWindowExceededError)
|
||||
): # don't retry a malformed request
|
||||
raise e
|
||||
|
@ -1700,10 +1723,20 @@ class Router:
|
|||
elif fallbacks is not None:
|
||||
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
|
||||
fallback_model_group = None
|
||||
for item in fallbacks:
|
||||
generic_fallback_idx: Optional[int] = None
|
||||
## check for specific model group-specific fallbacks
|
||||
for idx, item in enumerate(fallbacks):
|
||||
if list(item.keys())[0] == model_group:
|
||||
fallback_model_group = item[model_group]
|
||||
break
|
||||
elif list(item.keys())[0] == "*":
|
||||
generic_fallback_idx = idx
|
||||
## if none, check for generic fallback
|
||||
if (
|
||||
fallback_model_group is None
|
||||
and generic_fallback_idx is not None
|
||||
):
|
||||
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
|
||||
|
||||
if fallback_model_group is None:
|
||||
raise original_exception
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue