feat(router.py): enable default fallbacks

allow the user to define a generic list of fallbacks, used in case a new deployment is bad

Closes https://github.com/BerriAI/litellm/issues/3623
Krrish Dholakia 2024-05-13 17:49:56 -07:00
parent d7c28509d7
commit 5488bf4921
3 changed files with 132 additions and 11 deletions
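For context, a minimal usage sketch of the new parameter (the model names and keys below are illustrative, mirroring the test added in this commit; running it needs a real OPENAI_API_KEY):

# Hypothetical usage sketch of `default_fallbacks` (values are illustrative).
# Any model group without its own fallback entry falls back to "my-good-model".
import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "bad-model",
            "litellm_params": {"model": "openai/my-bad-model", "api_key": "my-bad-api-key"},
        },
        {
            "model_name": "my-good-model",
            "litellm_params": {"model": "gpt-4o", "api_key": os.getenv("OPENAI_API_KEY")},
        },
    ],
    default_fallbacks=["my-good-model"],
)

# A failing call to "bad-model" is retried against "my-good-model".
response = router.completion(
    model="bad-model",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)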

litellm/router.py

@@ -9,7 +9,7 @@
 import copy, httpx
 from datetime import datetime
-from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
+from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
 from typing_extensions import overload
 import random, threading, time, traceback, uuid
 import litellm, openai, hashlib, json
@@ -47,6 +47,7 @@ from litellm.types.router import (
     updateLiteLLMParams,
     RetryPolicy,
     AlertingConfig,
+    DeploymentTypedDict,
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -62,7 +63,7 @@ class Router:
     def __init__(
         self,
-        model_list: Optional[list] = None,
+        model_list: Optional[List[DeploymentTypedDict]] = None,
         ## CACHING ##
         redis_url: Optional[str] = None,
         redis_host: Optional[str] = None,
@@ -83,6 +84,9 @@ class Router:
         default_max_parallel_requests: Optional[int] = None,
         set_verbose: bool = False,
         debug_level: Literal["DEBUG", "INFO"] = "INFO",
+        default_fallbacks: Optional[
+            List[str]
+        ] = None,  # generic fallbacks, works across all deployments
         fallbacks: List = [],
         context_window_fallbacks: List = [],
         model_group_alias: Optional[dict] = {},
@@ -259,6 +263,11 @@ class Router:
         self.retry_after = retry_after
         self.routing_strategy = routing_strategy
         self.fallbacks = fallbacks or litellm.fallbacks
+        if default_fallbacks is not None:
+            if self.fallbacks is not None:
+                self.fallbacks.append({"*": default_fallbacks})
+            else:
+                self.fallbacks = [{"*": default_fallbacks}]
         self.context_window_fallbacks = (
             context_window_fallbacks or litellm.context_window_fallbacks
         )
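To make the merge above concrete (assumed, illustrative inputs, not from the diff): with fallbacks=[{"gpt-4": ["azure-gpt-4"]}] and default_fallbacks=["my-good-model"], the router ends up with

# illustrative result of the merge in __init__ (assumed inputs)
router.fallbacks == [{"gpt-4": ["azure-gpt-4"]}, {"*": ["my-good-model"]}]

i.e. the "*" entry is a catch-all consulted only when no model-group-specific entry matches (see the lookup hunks below).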
@@ -1471,13 +1480,21 @@
                     pass
                 elif fallbacks is not None:
                     verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
-                    for item in fallbacks:
-                        key_list = list(item.keys())
-                        if len(key_list) == 0:
-                            continue
-                        if key_list[0] == model_group:
+                    generic_fallback_idx: Optional[int] = None
+                    ## check for specific model group-specific fallbacks
+                    for idx, item in enumerate(fallbacks):
+                        if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
+                        elif list(item.keys())[0] == "*":
+                            generic_fallback_idx = idx
+                    ## if none, check for generic fallback
+                    if (
+                        fallback_model_group is None
+                        and generic_fallback_idx is not None
+                    ):
+                        fallback_model_group = fallbacks[generic_fallback_idx]["*"]
                     if fallback_model_group is None:
                         verbose_router_logger.info(
                             f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
@@ -1537,7 +1554,7 @@
         """
         _healthy_deployments = await self._async_get_healthy_deployments(
-            model=kwargs.get("model"),
+            model=kwargs.get("model") or "",
         )

         # raises an exception if this error should not be retries
@@ -1644,12 +1661,18 @@
         Try calling the function_with_retries
         If it fails after num_retries, fall back to another model group
         """
+        mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
         model_group = kwargs.get("model")
         fallbacks = kwargs.get("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.get(
             "context_window_fallbacks", self.context_window_fallbacks
         )
         try:
+            if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
+                raise Exception(
+                    f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
+                )
             response = self.function_with_retries(*args, **kwargs)
             return response
         except Exception as e:
@@ -1658,7 +1681,7 @@
                 try:
                     if (
                         hasattr(e, "status_code")
-                        and e.status_code == 400
+                        and e.status_code == 400  # type: ignore
                         and not isinstance(e, litellm.ContextWindowExceededError)
                     ):  # don't retry a malformed request
                         raise e
@@ -1700,10 +1723,20 @@
             elif fallbacks is not None:
                 verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
                 fallback_model_group = None
-                for item in fallbacks:
+                generic_fallback_idx: Optional[int] = None
+                ## check for specific model group-specific fallbacks
+                for idx, item in enumerate(fallbacks):
                     if list(item.keys())[0] == model_group:
                         fallback_model_group = item[model_group]
                         break
+                    elif list(item.keys())[0] == "*":
+                        generic_fallback_idx = idx
+                ## if none, check for generic fallback
+                if (
+                    fallback_model_group is None
+                    and generic_fallback_idx is not None
+                ):
+                    fallback_model_group = fallbacks[generic_fallback_idx]["*"]
                 if fallback_model_group is None:
                     raise original_exception

litellm/tests/test_router_fallbacks.py

@@ -1007,3 +1007,50 @@ async def test_service_unavailable_fallbacks(sync_mode):
     )

     assert response.model == "gpt-35-turbo"
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_default_model_fallbacks(sync_mode):
+    """
+    Related issue - https://github.com/BerriAI/litellm/issues/3623
+
+    If model misconfigured, setup a default model for generic fallback
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "bad-model",
+                "litellm_params": {
+                    "model": "openai/my-bad-model",
+                    "api_key": "my-bad-api-key",
+                },
+            },
+            {
+                "model_name": "my-good-model",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ],
+        default_fallbacks=["my-good-model"],
+    )
+
+    if sync_mode:
+        response = router.completion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+    else:
+        response = await router.acompletion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+
+    assert isinstance(response, litellm.ModelResponse)
+    assert response.model is not None and response.model == "gpt-4o"

litellm/types/router.py

@@ -1,4 +1,4 @@
-from typing import List, Optional, Union, Dict, Tuple, Literal
+from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
 import httpx
 from pydantic import BaseModel, validator, Field
 from .completion import CompletionRequest
@@ -277,6 +277,47 @@ class updateDeployment(BaseModel):
         protected_namespaces = ()


+class LiteLLMParamsTypedDict(TypedDict, total=False):
+    """
+    [TODO]
+    - allow additional params (not in list)
+    - set value to none if not set -> don't raise error if value not set
+    """
+
+    model: str
+    custom_llm_provider: Optional[str]
+    tpm: Optional[int]
+    rpm: Optional[int]
+    api_key: Optional[str]
+    api_base: Optional[str]
+    api_version: Optional[str]
+    timeout: Optional[Union[float, str, httpx.Timeout]]
+    stream_timeout: Optional[Union[float, str]]
+    max_retries: Optional[int]
+    organization: Optional[str]  # for openai orgs
+    ## UNIFIED PROJECT/REGION ##
+    region_name: Optional[str]
+    ## VERTEX AI ##
+    vertex_project: Optional[str]
+    vertex_location: Optional[str]
+    ## AWS BEDROCK / SAGEMAKER ##
+    aws_access_key_id: Optional[str]
+    aws_secret_access_key: Optional[str]
+    aws_region_name: Optional[str]
+    ## IBM WATSONX ##
+    watsonx_region_name: Optional[str]
+    ## CUSTOM PRICING ##
+    input_cost_per_token: Optional[float]
+    output_cost_per_token: Optional[float]
+    input_cost_per_second: Optional[float]
+    output_cost_per_second: Optional[float]
+
+
+class DeploymentTypedDict(TypedDict):
+    model_name: str
+    litellm_params: LiteLLMParamsTypedDict
+
+
 class Deployment(BaseModel):
     model_name: str
     litellm_params: LiteLLM_Params
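A brief sketch of what the new TypedDicts enable: annotating a model_list so a static type checker (e.g. mypy) can flag misspelled litellm_params keys. The values below are illustrative, not from the diff:

# Illustrative: typing a model_list with the new TypedDicts (assumed values).
from typing import List
from litellm.types.router import DeploymentTypedDict

model_list: List[DeploymentTypedDict] = [
    {
        "model_name": "my-good-model",
        "litellm_params": {"model": "gpt-4o", "api_key": "sk-my-key"},
    }
]
# e.g. a type checker would flag {"modle": "gpt-4o"} as an unknown key.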