mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(router.py): enable default fallbacks
allow user to define a generic list of fallbacks, in case a new deployment is bad Closes https://github.com/BerriAI/litellm/issues/3623
This commit is contained in:
parent
d7c28509d7
commit
5488bf4921
3 changed files with 132 additions and 11 deletions
|
@ -9,7 +9,7 @@
|
|||
|
||||
import copy, httpx
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
|
||||
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
|
||||
from typing_extensions import overload
|
||||
import random, threading, time, traceback, uuid
|
||||
import litellm, openai, hashlib, json
|
||||
|
@ -47,6 +47,7 @@ from litellm.types.router import (
|
|||
updateLiteLLMParams,
|
||||
RetryPolicy,
|
||||
AlertingConfig,
|
||||
DeploymentTypedDict,
|
||||
)
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
|
@ -62,7 +63,7 @@ class Router:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model_list: Optional[list] = None,
|
||||
model_list: Optional[List[DeploymentTypedDict]] = None,
|
||||
## CACHING ##
|
||||
redis_url: Optional[str] = None,
|
||||
redis_host: Optional[str] = None,
|
||||
|
@ -83,6 +84,9 @@ class Router:
|
|||
default_max_parallel_requests: Optional[int] = None,
|
||||
set_verbose: bool = False,
|
||||
debug_level: Literal["DEBUG", "INFO"] = "INFO",
|
||||
default_fallbacks: Optional[
|
||||
List[str]
|
||||
] = None, # generic fallbacks, works across all deployments
|
||||
fallbacks: List = [],
|
||||
context_window_fallbacks: List = [],
|
||||
model_group_alias: Optional[dict] = {},
|
||||
|
@ -259,6 +263,11 @@ class Router:
|
|||
self.retry_after = retry_after
|
||||
self.routing_strategy = routing_strategy
|
||||
self.fallbacks = fallbacks or litellm.fallbacks
|
||||
if default_fallbacks is not None:
|
||||
if self.fallbacks is not None:
|
||||
self.fallbacks.append({"*": default_fallbacks})
|
||||
else:
|
||||
self.fallbacks = [{"*": default_fallbacks}]
|
||||
self.context_window_fallbacks = (
|
||||
context_window_fallbacks or litellm.context_window_fallbacks
|
||||
)
|
||||
|
@ -1471,13 +1480,21 @@ class Router:
|
|||
pass
|
||||
elif fallbacks is not None:
|
||||
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
|
||||
for item in fallbacks:
|
||||
key_list = list(item.keys())
|
||||
if len(key_list) == 0:
|
||||
continue
|
||||
if key_list[0] == model_group:
|
||||
generic_fallback_idx: Optional[int] = None
|
||||
## check for specific model group-specific fallbacks
|
||||
for idx, item in enumerate(fallbacks):
|
||||
if list(item.keys())[0] == model_group:
|
||||
fallback_model_group = item[model_group]
|
||||
break
|
||||
elif list(item.keys())[0] == "*":
|
||||
generic_fallback_idx = idx
|
||||
## if none, check for generic fallback
|
||||
if (
|
||||
fallback_model_group is None
|
||||
and generic_fallback_idx is not None
|
||||
):
|
||||
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
|
||||
|
||||
if fallback_model_group is None:
|
||||
verbose_router_logger.info(
|
||||
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
|
||||
|
@ -1537,7 +1554,7 @@ class Router:
|
|||
|
||||
"""
|
||||
_healthy_deployments = await self._async_get_healthy_deployments(
|
||||
model=kwargs.get("model"),
|
||||
model=kwargs.get("model") or "",
|
||||
)
|
||||
|
||||
# raises an exception if this error should not be retries
|
||||
|
@ -1644,12 +1661,18 @@ class Router:
|
|||
Try calling the function_with_retries
|
||||
If it fails after num_retries, fall back to another model group
|
||||
"""
|
||||
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
|
||||
model_group = kwargs.get("model")
|
||||
fallbacks = kwargs.get("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.get(
|
||||
"context_window_fallbacks", self.context_window_fallbacks
|
||||
)
|
||||
try:
|
||||
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
|
||||
raise Exception(
|
||||
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
|
||||
)
|
||||
|
||||
response = self.function_with_retries(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
|
@ -1658,7 +1681,7 @@ class Router:
|
|||
try:
|
||||
if (
|
||||
hasattr(e, "status_code")
|
||||
and e.status_code == 400
|
||||
and e.status_code == 400 # type: ignore
|
||||
and not isinstance(e, litellm.ContextWindowExceededError)
|
||||
): # don't retry a malformed request
|
||||
raise e
|
||||
|
@ -1700,10 +1723,20 @@ class Router:
|
|||
elif fallbacks is not None:
|
||||
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
|
||||
fallback_model_group = None
|
||||
for item in fallbacks:
|
||||
generic_fallback_idx: Optional[int] = None
|
||||
## check for specific model group-specific fallbacks
|
||||
for idx, item in enumerate(fallbacks):
|
||||
if list(item.keys())[0] == model_group:
|
||||
fallback_model_group = item[model_group]
|
||||
break
|
||||
elif list(item.keys())[0] == "*":
|
||||
generic_fallback_idx = idx
|
||||
## if none, check for generic fallback
|
||||
if (
|
||||
fallback_model_group is None
|
||||
and generic_fallback_idx is not None
|
||||
):
|
||||
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
|
||||
|
||||
if fallback_model_group is None:
|
||||
raise original_exception
|
||||
|
|
|
@ -1007,3 +1007,50 @@ async def test_service_unavailable_fallbacks(sync_mode):
|
|||
)
|
||||
|
||||
assert response.model == "gpt-35-turbo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_model_fallbacks(sync_mode):
|
||||
"""
|
||||
Related issue - https://github.com/BerriAI/litellm/issues/3623
|
||||
|
||||
If model misconfigured, setup a default model for generic fallback
|
||||
"""
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "bad-model",
|
||||
"litellm_params": {
|
||||
"model": "openai/my-bad-model",
|
||||
"api_key": "my-bad-api-key",
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "my-good-model",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
],
|
||||
default_fallbacks=["my-good-model"],
|
||||
)
|
||||
|
||||
if sync_mode:
|
||||
response = router.completion(
|
||||
model="bad-model",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
mock_testing_fallbacks=True,
|
||||
mock_response="Hey! nice day",
|
||||
)
|
||||
else:
|
||||
response = await router.acompletion(
|
||||
model="bad-model",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
mock_testing_fallbacks=True,
|
||||
mock_response="Hey! nice day",
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.ModelResponse)
|
||||
assert response.model is not None and response.model == "gpt-4o"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional, Union, Dict, Tuple, Literal
|
||||
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
|
||||
import httpx
|
||||
from pydantic import BaseModel, validator, Field
|
||||
from .completion import CompletionRequest
|
||||
|
@ -277,6 +277,47 @@ class updateDeployment(BaseModel):
|
|||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLMParamsTypedDict(TypedDict, total=False):
|
||||
"""
|
||||
[TODO]
|
||||
- allow additional params (not in list)
|
||||
- set value to none if not set -> don't raise error if value not set
|
||||
"""
|
||||
|
||||
model: str
|
||||
custom_llm_provider: Optional[str]
|
||||
tpm: Optional[int]
|
||||
rpm: Optional[int]
|
||||
api_key: Optional[str]
|
||||
api_base: Optional[str]
|
||||
api_version: Optional[str]
|
||||
timeout: Optional[Union[float, str, httpx.Timeout]]
|
||||
stream_timeout: Optional[Union[float, str]]
|
||||
max_retries: Optional[int]
|
||||
organization: Optional[str] # for openai orgs
|
||||
## UNIFIED PROJECT/REGION ##
|
||||
region_name: Optional[str]
|
||||
## VERTEX AI ##
|
||||
vertex_project: Optional[str]
|
||||
vertex_location: Optional[str]
|
||||
## AWS BEDROCK / SAGEMAKER ##
|
||||
aws_access_key_id: Optional[str]
|
||||
aws_secret_access_key: Optional[str]
|
||||
aws_region_name: Optional[str]
|
||||
## IBM WATSONX ##
|
||||
watsonx_region_name: Optional[str]
|
||||
## CUSTOM PRICING ##
|
||||
input_cost_per_token: Optional[float]
|
||||
output_cost_per_token: Optional[float]
|
||||
input_cost_per_second: Optional[float]
|
||||
output_cost_per_second: Optional[float]
|
||||
|
||||
|
||||
class DeploymentTypedDict(TypedDict):
|
||||
model_name: str
|
||||
litellm_params: LiteLLMParamsTypedDict
|
||||
|
||||
|
||||
class Deployment(BaseModel):
|
||||
model_name: str
|
||||
litellm_params: LiteLLM_Params
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue