Merge pull request #3625 from BerriAI/litellm_router_default_fallbacks

Default routing fallbacks
Ishaan Jaff 2024-05-13 20:47:54 -07:00 committed by GitHub
commit ffcd6b6a63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 151 additions and 16 deletions
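For context, this PR adds a `default_fallbacks` option to the Router: a list of model groups to fall back to when a call fails and no group-specific fallback is configured. A minimal usage sketch, adapted from the test added in this PR (model names and keys below are placeholders):

```python
import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "bad-model",  # placeholder deployment that will fail
            "litellm_params": {
                "model": "openai/my-bad-model",
                "api_key": "my-bad-api-key",
            },
        },
        {
            "model_name": "my-good-model",  # healthy deployment used as the fallback
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
    ],
    # generic fallbacks, applied across all deployments; stored internally as {"*": [...]}
    default_fallbacks=["my-good-model"],
)

response = router.completion(
    model="bad-model",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
```

Internally the router appends `{"*": default_fallbacks}` to `self.fallbacks`, so group-specific fallback entries still take precedence (see the router diff below).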

View file

@@ -106,11 +106,12 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li
## Custom mapping list
Base case - we return the original exception.
Base case - we return a `litellm.APIConnectionError` exception (inherits from openai's `APIConnectionError`).
| custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
|----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
| openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| watsonx | | | | | | | | ✓ | | | |
| text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
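A short sketch of what the table means in practice, using a placeholder watsonx model id: provider errors surface as the litellm exception types listed above, and anything unmapped falls back to `litellm.APIConnectionError`.

```python
import litellm

try:
    litellm.completion(
        model="watsonx/my-model",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],
    )
except litellm.RateLimitError:
    # e.g. watsonx "token_quota_reached" errors are mapped to RateLimitError
    pass
except litellm.APIConnectionError:
    # base case: unmapped provider errors are re-raised as APIConnectionError
    pass
```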

View file

@@ -15,7 +15,6 @@ import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
import litellm
from ._logging import verbose_logger
from litellm import ( # type: ignore
client,

View file

@@ -9,7 +9,7 @@
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
from typing_extensions import overload
import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json
@@ -47,6 +47,7 @@ from litellm.types.router import (
updateLiteLLMParams,
RetryPolicy,
AlertingConfig,
DeploymentTypedDict,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -62,7 +63,7 @@ class Router:
def __init__(
self,
model_list: Optional[list] = None,
model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
## CACHING ##
redis_url: Optional[str] = None,
redis_host: Optional[str] = None,
@@ -83,6 +84,9 @@
default_max_parallel_requests: Optional[int] = None,
set_verbose: bool = False,
debug_level: Literal["DEBUG", "INFO"] = "INFO",
default_fallbacks: Optional[
List[str]
] = None, # generic fallbacks, works across all deployments
fallbacks: List = [],
context_window_fallbacks: List = [],
model_group_alias: Optional[dict] = {},
@@ -259,6 +263,11 @@
self.retry_after = retry_after
self.routing_strategy = routing_strategy
self.fallbacks = fallbacks or litellm.fallbacks
if default_fallbacks is not None:
if self.fallbacks is not None:
self.fallbacks.append({"*": default_fallbacks})
else:
self.fallbacks = [{"*": default_fallbacks}]
self.context_window_fallbacks = (
context_window_fallbacks or litellm.context_window_fallbacks
)
@@ -1471,13 +1480,21 @@
pass
elif fallbacks is not None:
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
for item in fallbacks:
key_list = list(item.keys())
if len(key_list) == 0:
continue
if key_list[0] == model_group:
generic_fallback_idx: Optional[int] = None
## check for model-group-specific fallbacks
for idx, item in enumerate(fallbacks):
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
elif list(item.keys())[0] == "*":
generic_fallback_idx = idx
## if none, check for generic fallback
if (
fallback_model_group is None
and generic_fallback_idx is not None
):
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
if fallback_model_group is None:
verbose_router_logger.info(
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
@@ -1537,7 +1554,7 @@
"""
_healthy_deployments = await self._async_get_healthy_deployments(
model=kwargs.get("model"),
model=kwargs.get("model") or "",
)
# raises an exception if this error should not be retried
@@ -1644,12 +1661,18 @@
Try calling the function_with_retries
If it fails after num_retries, fall back to another model group
"""
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
model_group = kwargs.get("model")
fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks
)
try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
)
response = self.function_with_retries(*args, **kwargs)
return response
except Exception as e:
@@ -1658,7 +1681,7 @@
try:
if (
hasattr(e, "status_code")
and e.status_code == 400
and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError)
): # don't retry a malformed request
raise e
@@ -1700,10 +1723,20 @@
elif fallbacks is not None:
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
fallback_model_group = None
for item in fallbacks:
generic_fallback_idx: Optional[int] = None
## check for model-group-specific fallbacks
for idx, item in enumerate(fallbacks):
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
elif list(item.keys())[0] == "*":
generic_fallback_idx = idx
## if none, check for generic fallback
if (
fallback_model_group is None
and generic_fallback_idx is not None
):
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
if fallback_model_group is None:
raise original_exception
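Both the async and sync retry paths now use the same resolution order: a model-group-specific fallback entry wins, otherwise the generic "*" entry added via `default_fallbacks` is used. A standalone sketch of that lookup (illustrative helper, not part of the PR):

```python
from typing import Dict, List, Optional


def resolve_fallback_group(
    fallbacks: List[Dict[str, List[str]]], model_group: str
) -> Optional[List[str]]:
    """Illustrative only: prefer a group-specific fallback entry, else the generic '*' entry."""
    generic: Optional[List[str]] = None
    for item in fallbacks:
        key = list(item.keys())[0]
        if key == model_group:
            return item[model_group]  # specific entry wins immediately
        if key == "*":
            generic = item["*"]  # remember the generic entry, keep scanning
    return generic


# resolve_fallback_group([{"*": ["my-good-model"]}], "bad-model") -> ["my-good-model"]
```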

View file

@@ -26,7 +26,7 @@ model_list = [
}
]
router = litellm.Router(model_list=model_list)
router = litellm.Router(model_list=model_list) # type: ignore
async def _openai_completion():

View file

@@ -3342,6 +3342,8 @@ def test_completion_watsonx():
print(response)
except litellm.APIError as e:
pass
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -3361,6 +3363,8 @@ def test_completion_stream_watsonx():
print(chunk)
except litellm.APIError as e:
pass
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -3425,6 +3429,8 @@ async def test_acompletion_watsonx():
)
# Add any assertions here to check the response
print(response)
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@@ -3445,6 +3451,8 @@ async def test_acompletion_stream_watsonx():
# Add any assertions here to check the response
async for chunk in response:
print(chunk)
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@@ -494,6 +494,8 @@ def test_watsonx_embeddings():
)
print(f"response: {response}")
assert isinstance(response.usage, litellm.Usage)
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@@ -1007,3 +1007,50 @@ async def test_service_unavailable_fallbacks(sync_mode):
)
assert response.model == "gpt-35-turbo"
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_default_model_fallbacks(sync_mode):
"""
Related issue - https://github.com/BerriAI/litellm/issues/3623
If a model is misconfigured, set up a default model as a generic fallback
"""
router = Router(
model_list=[
{
"model_name": "bad-model",
"litellm_params": {
"model": "openai/my-bad-model",
"api_key": "my-bad-api-key",
},
},
{
"model_name": "my-good-model",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
],
default_fallbacks=["my-good-model"],
)
if sync_mode:
response = router.completion(
model="bad-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
mock_testing_fallbacks=True,
mock_response="Hey! nice day",
)
else:
response = await router.acompletion(
model="bad-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
mock_testing_fallbacks=True,
mock_response="Hey! nice day",
)
assert isinstance(response, litellm.ModelResponse)
assert response.model is not None and response.model == "gpt-4o"

View file

@@ -456,7 +456,8 @@ def test_completion_claude_stream():
print(f"completion_response: {complete_response}")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_claude_stream()
def test_completion_claude_2_stream():
litellm.set_verbose = True
@@ -1416,6 +1417,8 @@ def test_completion_watsonx_stream():
raise Exception("finish reason not set for last chunk")
if complete_response.strip() == "":
raise Exception("Empty response received")
except litellm.RateLimitError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@@ -1,4 +1,4 @@
from typing import List, Optional, Union, Dict, Tuple, Literal
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
import httpx
from pydantic import BaseModel, validator, Field
from .completion import CompletionRequest
@@ -277,6 +277,47 @@ class updateDeployment(BaseModel):
protected_namespaces = ()
class LiteLLMParamsTypedDict(TypedDict, total=False):
"""
[TODO]
- allow additional params (not in list)
- set value to None if not set -> don't raise an error if a value isn't set
"""
model: str
custom_llm_provider: Optional[str]
tpm: Optional[int]
rpm: Optional[int]
api_key: Optional[str]
api_base: Optional[str]
api_version: Optional[str]
timeout: Optional[Union[float, str, httpx.Timeout]]
stream_timeout: Optional[Union[float, str]]
max_retries: Optional[int]
organization: Optional[str] # for openai orgs
## UNIFIED PROJECT/REGION ##
region_name: Optional[str]
## VERTEX AI ##
vertex_project: Optional[str]
vertex_location: Optional[str]
## AWS BEDROCK / SAGEMAKER ##
aws_access_key_id: Optional[str]
aws_secret_access_key: Optional[str]
aws_region_name: Optional[str]
## IBM WATSONX ##
watsonx_region_name: Optional[str]
## CUSTOM PRICING ##
input_cost_per_token: Optional[float]
output_cost_per_token: Optional[float]
input_cost_per_second: Optional[float]
output_cost_per_second: Optional[float]
class DeploymentTypedDict(TypedDict):
model_name: str
litellm_params: LiteLLMParamsTypedDict
class Deployment(BaseModel):
model_name: str
litellm_params: LiteLLM_Params
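With the new TypedDicts, a `model_list` entry can be written so that a type checker validates the keys. A rough sketch under that assumption (values below are placeholders):

```python
from litellm import Router
from litellm.types.router import DeploymentTypedDict, LiteLLMParamsTypedDict

litellm_params: LiteLLMParamsTypedDict = {
    "model": "azure/my-deployment",  # placeholder
    "api_key": "my-azure-api-key",  # placeholder
    "api_base": "https://example-endpoint.openai.azure.com",  # placeholder
    "api_version": "2023-07-01-preview",
    "rpm": 100,
}

deployment: DeploymentTypedDict = {
    "model_name": "gpt-3.5-turbo",
    "litellm_params": litellm_params,
}

# model_list is now typed as Optional[List[Union[DeploymentTypedDict, Dict]]]
router = Router(model_list=[deployment])
```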

View file

@@ -13,6 +13,7 @@ import dotenv, json, traceback, threading, base64, ast
import subprocess, os
from os.path import abspath, join, dirname
import litellm, openai
import itertools
import random, uuid, requests # type: ignore
from functools import wraps
@@ -8514,7 +8515,7 @@ def exception_type(
request=original_exception.request,
)
elif custom_llm_provider == "watsonx":
if "token_quota_reached" in error_response:
if "token_quota_reached" in error_str:
exception_mapping_worked = True
raise RateLimitError(
message=f"WatsonxException: Rate Limit Errror - {error_str}",