Merge pull request #3625 from BerriAI/litellm_router_default_fallbacks

Default routing fallbacks
Commit ffcd6b6a63 by Ishaan Jaff on 2024-05-13 20:47:54 -07:00, committed by GitHub
10 changed files with 151 additions and 16 deletions


@@ -106,11 +106,12 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li
 ## Custom mapping list

-Base case - we return the original exception.
+Base case - we return `litellm.APIConnectionError` exception (inherits from openai's APIConnectionError exception).

 | custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
 |----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
 | openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
+| watsonx | | | | | | | | ✓ | | | |
 | text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
 | custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
 | openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
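For callers, the mapping above works roughly as follows. A hypothetical sketch (the model id and message are placeholders, not part of this commit): errors with a specific mapping surface as that class, e.g. a watsonx `token_quota_reached` error arrives as `litellm.RateLimitError`, while unmapped provider errors fall back to `litellm.APIConnectionError`.

```python
import litellm

try:
    litellm.completion(
        model="watsonx/ibm/granite-13b-chat-v2",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],
    )
except litellm.RateLimitError:
    # mapped case, e.g. watsonx "token_quota_reached" (see table above)
    print("rate limited - back off and retry")
except litellm.APIConnectionError as e:
    # base case: no specific mapping for this provider error
    print(f"unmapped provider error: {e}")
```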


@@ -15,7 +15,6 @@ import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
 import litellm
 from ._logging import verbose_logger
 from litellm import (  # type: ignore
     client,


@@ -9,7 +9,7 @@
 import copy, httpx
 from datetime import datetime
-from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
+from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
 from typing_extensions import overload
 import random, threading, time, traceback, uuid
 import litellm, openai, hashlib, json
@@ -47,6 +47,7 @@ from litellm.types.router import (
     updateLiteLLMParams,
     RetryPolicy,
     AlertingConfig,
+    DeploymentTypedDict,
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -62,7 +63,7 @@ class Router:
     def __init__(
         self,
-        model_list: Optional[list] = None,
+        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
         ## CACHING ##
         redis_url: Optional[str] = None,
         redis_host: Optional[str] = None,
@@ -83,6 +84,9 @@ class Router:
         default_max_parallel_requests: Optional[int] = None,
         set_verbose: bool = False,
         debug_level: Literal["DEBUG", "INFO"] = "INFO",
+        default_fallbacks: Optional[
+            List[str]
+        ] = None,  # generic fallbacks, works across all deployments
         fallbacks: List = [],
         context_window_fallbacks: List = [],
         model_group_alias: Optional[dict] = {},
@@ -259,6 +263,11 @@ class Router:
         self.retry_after = retry_after
         self.routing_strategy = routing_strategy
         self.fallbacks = fallbacks or litellm.fallbacks
+        if default_fallbacks is not None:
+            if self.fallbacks is not None:
+                self.fallbacks.append({"*": default_fallbacks})
+            else:
+                self.fallbacks = [{"*": default_fallbacks}]
         self.context_window_fallbacks = (
             context_window_fallbacks or litellm.context_window_fallbacks
         )
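For reference, the merge above folds `default_fallbacks` into `self.fallbacks` as a wildcard (`"*"`) entry. A minimal sketch, assuming placeholder deployments and keys (not part of this commit):

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "bad-model",
            "litellm_params": {"model": "openai/my-bad-model", "api_key": "bad-key"},
        },
        {
            "model_name": "my-good-model",
            "litellm_params": {"model": "gpt-4o", "api_key": "sk-..."},  # placeholder key
        },
    ],
    default_fallbacks=["my-good-model"],
)

# With no explicit `fallbacks` passed, the wildcard entry is all there is:
print(router.fallbacks)  # expected: [{'*': ['my-good-model']}]
```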
@@ -1471,13 +1480,21 @@ class Router:
                         pass
                 elif fallbacks is not None:
                     verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
-                    for item in fallbacks:
-                        key_list = list(item.keys())
-                        if len(key_list) == 0:
-                            continue
-                        if key_list[0] == model_group:
+                    generic_fallback_idx: Optional[int] = None
+                    ## check for specific model group-specific fallbacks
+                    for idx, item in enumerate(fallbacks):
+                        if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
+                        elif list(item.keys())[0] == "*":
+                            generic_fallback_idx = idx
+                    ## if none, check for generic fallback
+                    if (
+                        fallback_model_group is None
+                        and generic_fallback_idx is not None
+                    ):
+                        fallback_model_group = fallbacks[generic_fallback_idx]["*"]
                     if fallback_model_group is None:
                         verbose_router_logger.info(
                             f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
@@ -1537,7 +1554,7 @@ class Router:
         """
         _healthy_deployments = await self._async_get_healthy_deployments(
-            model=kwargs.get("model"),
+            model=kwargs.get("model") or "",
         )

         # raises an exception if this error should not be retries
@@ -1644,12 +1661,18 @@ class Router:
         Try calling the function_with_retries
         If it fails after num_retries, fall back to another model group
         """
+        mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
         model_group = kwargs.get("model")
         fallbacks = kwargs.get("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.get(
             "context_window_fallbacks", self.context_window_fallbacks
         )
         try:
+            if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
+                raise Exception(
+                    f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
+                )
             response = self.function_with_retries(*args, **kwargs)
             return response
         except Exception as e:
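The new `mock_testing_fallbacks` kwarg makes this branch easy to exercise without a real provider failure. A hypothetical sketch (placeholder model names and key; `mock_response` keeps the fallback call itself offline):

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-good-model",
            "litellm_params": {"model": "gpt-4o", "api_key": "sk-..."},  # placeholder key
        },
    ],
    default_fallbacks=["my-good-model"],
)

response = router.completion(
    model="bad-model",  # not a configured group; the mock exception triggers the fallback path
    messages=[{"role": "user", "content": "ping"}],
    mock_testing_fallbacks=True,
    mock_response="pong",
)
print(response.model)  # expected to report the fallback's underlying model, e.g. "gpt-4o"
```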
@@ -1658,7 +1681,7 @@ class Router:
             try:
                 if (
                     hasattr(e, "status_code")
-                    and e.status_code == 400
+                    and e.status_code == 400  # type: ignore
                     and not isinstance(e, litellm.ContextWindowExceededError)
                 ):  # don't retry a malformed request
                     raise e
@@ -1700,10 +1723,20 @@ class Router:
                 elif fallbacks is not None:
                     verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
                     fallback_model_group = None
-                    for item in fallbacks:
+                    generic_fallback_idx: Optional[int] = None
+                    ## check for specific model group-specific fallbacks
+                    for idx, item in enumerate(fallbacks):
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
+                        elif list(item.keys())[0] == "*":
+                            generic_fallback_idx = idx
+                    ## if none, check for generic fallback
+                    if (
+                        fallback_model_group is None
+                        and generic_fallback_idx is not None
+                    ):
+                        fallback_model_group = fallbacks[generic_fallback_idx]["*"]
                     if fallback_model_group is None:
                         raise original_exception
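Both the async and sync paths above use the same selection order: an exact model-group entry wins, otherwise the generic `"*"` entry (which is where `default_fallbacks` ends up). A condensed, standalone sketch of that logic:

```python
from typing import Dict, List, Optional


def pick_fallback_group(
    fallbacks: List[Dict[str, List[str]]], model_group: str
) -> Optional[List[str]]:
    """Prefer an exact model-group entry; otherwise use the generic "*" entry."""
    generic: Optional[List[str]] = None
    for item in fallbacks:
        key = list(item.keys())[0]
        if key == model_group:
            return item[model_group]
        if key == "*":
            generic = item["*"]
    return generic


# e.g. with default_fallbacks=["my-good-model"] merged in as a {"*": [...]} entry:
print(pick_fallback_group([{"*": ["my-good-model"]}], "bad-model"))  # ['my-good-model']
```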


@@ -26,7 +26,7 @@ model_list = [
     }
 ]

-router = litellm.Router(model_list=model_list)
+router = litellm.Router(model_list=model_list)  # type: ignore


 async def _openai_completion():


@@ -3342,6 +3342,8 @@ def test_completion_watsonx():
         print(response)
     except litellm.APIError as e:
         pass
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -3361,6 +3363,8 @@ def test_completion_stream_watsonx():
             print(chunk)
     except litellm.APIError as e:
         pass
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -3425,6 +3429,8 @@ async def test_acompletion_watsonx():
         )
         # Add any assertions here to check the response
         print(response)
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -3445,6 +3451,8 @@ async def test_acompletion_stream_watsonx():
         # Add any assertions here to check the response
         async for chunk in response:
             print(chunk)
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


@@ -494,6 +494,8 @@ def test_watsonx_embeddings():
         )
         print(f"response: {response}")
         assert isinstance(response.usage, litellm.Usage)
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


@@ -1007,3 +1007,50 @@ async def test_service_unavailable_fallbacks(sync_mode):
         )

     assert response.model == "gpt-35-turbo"
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_default_model_fallbacks(sync_mode):
+    """
+    Related issue - https://github.com/BerriAI/litellm/issues/3623
+
+    If model misconfigured, setup a default model for generic fallback
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "bad-model",
+                "litellm_params": {
+                    "model": "openai/my-bad-model",
+                    "api_key": "my-bad-api-key",
+                },
+            },
+            {
+                "model_name": "my-good-model",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ],
+        default_fallbacks=["my-good-model"],
+    )
+
+    if sync_mode:
+        response = router.completion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+    else:
+        response = await router.acompletion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+
+    assert isinstance(response, litellm.ModelResponse)
+    assert response.model is not None and response.model == "gpt-4o"


@@ -457,6 +457,7 @@ def test_completion_claude_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

 # test_completion_claude_stream()

 def test_completion_claude_2_stream():
     litellm.set_verbose = True
@@ -1416,6 +1417,8 @@ def test_completion_watsonx_stream():
             raise Exception("finish reason not set for last chunk")
         if complete_response.strip() == "":
             raise Exception("Empty response received")
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


@@ -1,4 +1,4 @@
-from typing import List, Optional, Union, Dict, Tuple, Literal
+from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
 import httpx
 from pydantic import BaseModel, validator, Field
 from .completion import CompletionRequest
@@ -277,6 +277,47 @@ class updateDeployment(BaseModel):
         protected_namespaces = ()


+class LiteLLMParamsTypedDict(TypedDict, total=False):
+    """
+    [TODO]
+    - allow additional params (not in list)
+    - set value to none if not set -> don't raise error if value not set
+    """
+
+    model: str
+    custom_llm_provider: Optional[str]
+    tpm: Optional[int]
+    rpm: Optional[int]
+    api_key: Optional[str]
+    api_base: Optional[str]
+    api_version: Optional[str]
+    timeout: Optional[Union[float, str, httpx.Timeout]]
+    stream_timeout: Optional[Union[float, str]]
+    max_retries: Optional[int]
+    organization: Optional[str]  # for openai orgs
+    ## UNIFIED PROJECT/REGION ##
+    region_name: Optional[str]
+    ## VERTEX AI ##
+    vertex_project: Optional[str]
+    vertex_location: Optional[str]
+    ## AWS BEDROCK / SAGEMAKER ##
+    aws_access_key_id: Optional[str]
+    aws_secret_access_key: Optional[str]
+    aws_region_name: Optional[str]
+    ## IBM WATSONX ##
+    watsonx_region_name: Optional[str]
+    ## CUSTOM PRICING ##
+    input_cost_per_token: Optional[float]
+    output_cost_per_token: Optional[float]
+    input_cost_per_second: Optional[float]
+    output_cost_per_second: Optional[float]
+
+
+class DeploymentTypedDict(TypedDict):
+    model_name: str
+    litellm_params: LiteLLMParamsTypedDict
+
+
 class Deployment(BaseModel):
     model_name: str
     litellm_params: LiteLLM_Params
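These TypedDicts exist so a deployment list can be type-checked; the Router `model_list` parameter above now accepts them. A minimal sketch (placeholder names and key):

```python
from typing import List

from litellm.types.router import DeploymentTypedDict

model_list: List[DeploymentTypedDict] = [
    {
        "model_name": "my-good-model",
        "litellm_params": {
            "model": "gpt-4o",
            "api_key": "sk-...",  # placeholder
        },
    }
]
```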


@@ -13,6 +13,7 @@ import dotenv, json, traceback, threading, base64, ast
 import subprocess, os
 from os.path import abspath, join, dirname
 import litellm, openai
 import itertools
 import random, uuid, requests  # type: ignore
 from functools import wraps
@@ -8514,7 +8515,7 @@ def exception_type(
                         request=original_exception.request,
                     )
             elif custom_llm_provider == "watsonx":
-                if "token_quota_reached" in error_response:
+                if "token_quota_reached" in error_str:
                     exception_mapping_worked = True
                     raise RateLimitError(
                         message=f"WatsonxException: Rate Limit Errror - {error_str}",