forked from phoenix/litellm-mirror
Merge pull request #3625 from BerriAI/litellm_router_default_fallbacks
Default routing fallbacks
This commit is contained in:
commit
ffcd6b6a63
10 changed files with 151 additions and 16 deletions
|
@ -106,11 +106,12 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li
|
|||
|
||||
## Custom mapping list
|
||||
|
||||
Base case - we return the original exception.
|
||||
Base case - we return `litellm.APIConnectionError` exception (inherits from openai's APIConnectionError exception).
|
||||
|
||||
| custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
|
||||
|----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
|
||||
| openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
| watsonx | | | | | | | |✓| | | |
|
||||
| text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
| custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
| openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
|
|
|
@ -15,7 +15,6 @@ import dotenv, traceback, random, asyncio, time, contextvars
|
|||
from copy import deepcopy
|
||||
import httpx
|
||||
import litellm
|
||||
|
||||
from ._logging import verbose_logger
|
||||
from litellm import ( # type: ignore
|
||||
client,
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
import copy, httpx
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
|
||||
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
|
||||
from typing_extensions import overload
|
||||
import random, threading, time, traceback, uuid
|
||||
import litellm, openai, hashlib, json
|
||||
|
@ -47,6 +47,7 @@ from litellm.types.router import (
|
|||
updateLiteLLMParams,
|
||||
RetryPolicy,
|
||||
AlertingConfig,
|
||||
DeploymentTypedDict,
|
||||
)
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
|
@ -62,7 +63,7 @@ class Router:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model_list: Optional[list] = None,
|
||||
model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
|
||||
## CACHING ##
|
||||
redis_url: Optional[str] = None,
|
||||
redis_host: Optional[str] = None,
|
||||
|
@ -83,6 +84,9 @@ class Router:
|
|||
default_max_parallel_requests: Optional[int] = None,
|
||||
set_verbose: bool = False,
|
||||
debug_level: Literal["DEBUG", "INFO"] = "INFO",
|
||||
default_fallbacks: Optional[
|
||||
List[str]
|
||||
] = None, # generic fallbacks, works across all deployments
|
||||
fallbacks: List = [],
|
||||
context_window_fallbacks: List = [],
|
||||
model_group_alias: Optional[dict] = {},
|
||||
|
@ -259,6 +263,11 @@ class Router:
|
|||
self.retry_after = retry_after
|
||||
self.routing_strategy = routing_strategy
|
||||
self.fallbacks = fallbacks or litellm.fallbacks
|
||||
if default_fallbacks is not None:
|
||||
if self.fallbacks is not None:
|
||||
self.fallbacks.append({"*": default_fallbacks})
|
||||
else:
|
||||
self.fallbacks = [{"*": default_fallbacks}]
|
||||
self.context_window_fallbacks = (
|
||||
context_window_fallbacks or litellm.context_window_fallbacks
|
||||
)
|
||||
|
@ -1471,13 +1480,21 @@ class Router:
|
|||
pass
|
||||
elif fallbacks is not None:
|
||||
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
|
||||
for item in fallbacks:
|
||||
key_list = list(item.keys())
|
||||
if len(key_list) == 0:
|
||||
continue
|
||||
if key_list[0] == model_group:
|
||||
generic_fallback_idx: Optional[int] = None
|
||||
## check for specific model group-specific fallbacks
|
||||
for idx, item in enumerate(fallbacks):
|
||||
if list(item.keys())[0] == model_group:
|
||||
fallback_model_group = item[model_group]
|
||||
break
|
||||
elif list(item.keys())[0] == "*":
|
||||
generic_fallback_idx = idx
|
||||
## if none, check for generic fallback
|
||||
if (
|
||||
fallback_model_group is None
|
||||
and generic_fallback_idx is not None
|
||||
):
|
||||
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
|
||||
|
||||
if fallback_model_group is None:
|
||||
verbose_router_logger.info(
|
||||
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
|
||||
|
@ -1537,7 +1554,7 @@ class Router:
|
|||
|
||||
"""
|
||||
_healthy_deployments = await self._async_get_healthy_deployments(
|
||||
model=kwargs.get("model"),
|
||||
model=kwargs.get("model") or "",
|
||||
)
|
||||
|
||||
# raises an exception if this error should not be retries
|
||||
|
@ -1644,12 +1661,18 @@ class Router:
|
|||
Try calling the function_with_retries
|
||||
If it fails after num_retries, fall back to another model group
|
||||
"""
|
||||
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
|
||||
model_group = kwargs.get("model")
|
||||
fallbacks = kwargs.get("fallbacks", self.fallbacks)
|
||||
context_window_fallbacks = kwargs.get(
|
||||
"context_window_fallbacks", self.context_window_fallbacks
|
||||
)
|
||||
try:
|
||||
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
|
||||
raise Exception(
|
||||
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
|
||||
)
|
||||
|
||||
response = self.function_with_retries(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
|
@ -1658,7 +1681,7 @@ class Router:
|
|||
try:
|
||||
if (
|
||||
hasattr(e, "status_code")
|
||||
and e.status_code == 400
|
||||
and e.status_code == 400 # type: ignore
|
||||
and not isinstance(e, litellm.ContextWindowExceededError)
|
||||
): # don't retry a malformed request
|
||||
raise e
|
||||
|
@ -1700,10 +1723,20 @@ class Router:
|
|||
elif fallbacks is not None:
|
||||
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
|
||||
fallback_model_group = None
|
||||
for item in fallbacks:
|
||||
generic_fallback_idx: Optional[int] = None
|
||||
## check for specific model group-specific fallbacks
|
||||
for idx, item in enumerate(fallbacks):
|
||||
if list(item.keys())[0] == model_group:
|
||||
fallback_model_group = item[model_group]
|
||||
break
|
||||
elif list(item.keys())[0] == "*":
|
||||
generic_fallback_idx = idx
|
||||
## if none, check for generic fallback
|
||||
if (
|
||||
fallback_model_group is None
|
||||
and generic_fallback_idx is not None
|
||||
):
|
||||
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
|
||||
|
||||
if fallback_model_group is None:
|
||||
raise original_exception
|
||||
|
|
|
@ -26,7 +26,7 @@ model_list = [
|
|||
}
|
||||
]
|
||||
|
||||
router = litellm.Router(model_list=model_list)
|
||||
router = litellm.Router(model_list=model_list) # type: ignore
|
||||
|
||||
|
||||
async def _openai_completion():
|
||||
|
|
|
@ -3342,6 +3342,8 @@ def test_completion_watsonx():
|
|||
print(response)
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -3361,6 +3363,8 @@ def test_completion_stream_watsonx():
|
|||
print(chunk)
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -3425,6 +3429,8 @@ async def test_acompletion_watsonx():
|
|||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -3445,6 +3451,8 @@ async def test_acompletion_stream_watsonx():
|
|||
# Add any assertions here to check the response
|
||||
async for chunk in response:
|
||||
print(chunk)
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
|
@ -494,6 +494,8 @@ def test_watsonx_embeddings():
|
|||
)
|
||||
print(f"response: {response}")
|
||||
assert isinstance(response.usage, litellm.Usage)
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
|
@ -1007,3 +1007,50 @@ async def test_service_unavailable_fallbacks(sync_mode):
|
|||
)
|
||||
|
||||
assert response.model == "gpt-35-turbo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_model_fallbacks(sync_mode):
|
||||
"""
|
||||
Related issue - https://github.com/BerriAI/litellm/issues/3623
|
||||
|
||||
If model misconfigured, setup a default model for generic fallback
|
||||
"""
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "bad-model",
|
||||
"litellm_params": {
|
||||
"model": "openai/my-bad-model",
|
||||
"api_key": "my-bad-api-key",
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "my-good-model",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
],
|
||||
default_fallbacks=["my-good-model"],
|
||||
)
|
||||
|
||||
if sync_mode:
|
||||
response = router.completion(
|
||||
model="bad-model",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
mock_testing_fallbacks=True,
|
||||
mock_response="Hey! nice day",
|
||||
)
|
||||
else:
|
||||
response = await router.acompletion(
|
||||
model="bad-model",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
mock_testing_fallbacks=True,
|
||||
mock_response="Hey! nice day",
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.ModelResponse)
|
||||
assert response.model is not None and response.model == "gpt-4o"
|
||||
|
|
|
@ -456,7 +456,8 @@ def test_completion_claude_stream():
|
|||
print(f"completion_response: {complete_response}")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
|
||||
# test_completion_claude_stream()
|
||||
def test_completion_claude_2_stream():
|
||||
litellm.set_verbose = True
|
||||
|
@ -1416,6 +1417,8 @@ def test_completion_watsonx_stream():
|
|||
raise Exception("finish reason not set for last chunk")
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional, Union, Dict, Tuple, Literal
|
||||
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
|
||||
import httpx
|
||||
from pydantic import BaseModel, validator, Field
|
||||
from .completion import CompletionRequest
|
||||
|
@ -277,6 +277,47 @@ class updateDeployment(BaseModel):
|
|||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLMParamsTypedDict(TypedDict, total=False):
|
||||
"""
|
||||
[TODO]
|
||||
- allow additional params (not in list)
|
||||
- set value to none if not set -> don't raise error if value not set
|
||||
"""
|
||||
|
||||
model: str
|
||||
custom_llm_provider: Optional[str]
|
||||
tpm: Optional[int]
|
||||
rpm: Optional[int]
|
||||
api_key: Optional[str]
|
||||
api_base: Optional[str]
|
||||
api_version: Optional[str]
|
||||
timeout: Optional[Union[float, str, httpx.Timeout]]
|
||||
stream_timeout: Optional[Union[float, str]]
|
||||
max_retries: Optional[int]
|
||||
organization: Optional[str] # for openai orgs
|
||||
## UNIFIED PROJECT/REGION ##
|
||||
region_name: Optional[str]
|
||||
## VERTEX AI ##
|
||||
vertex_project: Optional[str]
|
||||
vertex_location: Optional[str]
|
||||
## AWS BEDROCK / SAGEMAKER ##
|
||||
aws_access_key_id: Optional[str]
|
||||
aws_secret_access_key: Optional[str]
|
||||
aws_region_name: Optional[str]
|
||||
## IBM WATSONX ##
|
||||
watsonx_region_name: Optional[str]
|
||||
## CUSTOM PRICING ##
|
||||
input_cost_per_token: Optional[float]
|
||||
output_cost_per_token: Optional[float]
|
||||
input_cost_per_second: Optional[float]
|
||||
output_cost_per_second: Optional[float]
|
||||
|
||||
|
||||
class DeploymentTypedDict(TypedDict):
|
||||
model_name: str
|
||||
litellm_params: LiteLLMParamsTypedDict
|
||||
|
||||
|
||||
class Deployment(BaseModel):
|
||||
model_name: str
|
||||
litellm_params: LiteLLM_Params
|
||||
|
|
|
@ -13,6 +13,7 @@ import dotenv, json, traceback, threading, base64, ast
|
|||
import subprocess, os
|
||||
from os.path import abspath, join, dirname
|
||||
import litellm, openai
|
||||
|
||||
import itertools
|
||||
import random, uuid, requests # type: ignore
|
||||
from functools import wraps
|
||||
|
@ -8514,7 +8515,7 @@ def exception_type(
|
|||
request=original_exception.request,
|
||||
)
|
||||
elif custom_llm_provider == "watsonx":
|
||||
if "token_quota_reached" in error_response:
|
||||
if "token_quota_reached" in error_str:
|
||||
exception_mapping_worked = True
|
||||
raise RateLimitError(
|
||||
message=f"WatsonxException: Rate Limit Errror - {error_str}",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue