forked from phoenix/litellm-mirror
Merge pull request #3625 from BerriAI/litellm_router_default_fallbacks
Default routing fallbacks
Commit ffcd6b6a63
10 changed files with 151 additions and 16 deletions
@@ -106,11 +106,12 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li

 ## Custom mapping list

-Base case - we return the original exception.
+Base case - we return `litellm.APIConnectionError` exception (inherits from openai's APIConnectionError exception).

 | custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
 |----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
 | openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
+| watsonx | | | | | | | |✓| | | |
 | text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
 | custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
 | openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
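For callers, the mapping above means provider failures surface as OpenAI-compatible litellm exceptions, and anything without a specific entry in the table now arrives as `litellm.APIConnectionError` rather than the raw provider error. A minimal sketch of that pattern (the model id and prompt are placeholders):

```python
import litellm

try:
    litellm.completion(
        model="gpt-3.5-turbo",  # placeholder model
        messages=[{"role": "user", "content": "Hello"}],
    )
except litellm.ContextWindowExceededError:
    # specific mapping, listed in the table above
    ...
except litellm.APIConnectionError:
    # base case: provider errors with no specific mapping land here
    ...
```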
@@ -15,7 +15,6 @@ import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
 import litellm

 from ._logging import verbose_logger
 from litellm import ( # type: ignore
     client,
@@ -9,7 +9,7 @@

 import copy, httpx
 from datetime import datetime
-from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
+from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
 from typing_extensions import overload
 import random, threading, time, traceback, uuid
 import litellm, openai, hashlib, json
@@ -47,6 +47,7 @@ from litellm.types.router import (
     updateLiteLLMParams,
     RetryPolicy,
     AlertingConfig,
+    DeploymentTypedDict,
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -62,7 +63,7 @@ class Router:

     def __init__(
         self,
-        model_list: Optional[list] = None,
+        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
         ## CACHING ##
         redis_url: Optional[str] = None,
         redis_host: Optional[str] = None,
@@ -83,6 +84,9 @@ class Router:
         default_max_parallel_requests: Optional[int] = None,
         set_verbose: bool = False,
         debug_level: Literal["DEBUG", "INFO"] = "INFO",
+        default_fallbacks: Optional[
+            List[str]
+        ] = None,  # generic fallbacks, works across all deployments
         fallbacks: List = [],
         context_window_fallbacks: List = [],
         model_group_alias: Optional[dict] = {},
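The new `default_fallbacks` argument names deployments to fall back to whenever a model group fails and has no specific `fallbacks` entry of its own. A minimal sketch, mirroring the new test added below (deployment names and keys are made up):

```python
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "bad-model",
            "litellm_params": {"model": "openai/my-bad-model", "api_key": "my-bad-api-key"},
        },
        {
            "model_name": "my-good-model",
            "litellm_params": {"model": "gpt-4o", "api_key": os.getenv("OPENAI_API_KEY")},
        },
    ],
    # used for any model group that fails and has no specific fallback entry
    default_fallbacks=["my-good-model"],
)
```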
@@ -259,6 +263,11 @@ class Router:
         self.retry_after = retry_after
         self.routing_strategy = routing_strategy
         self.fallbacks = fallbacks or litellm.fallbacks
+        if default_fallbacks is not None:
+            if self.fallbacks is not None:
+                self.fallbacks.append({"*": default_fallbacks})
+            else:
+                self.fallbacks = [{"*": default_fallbacks}]
         self.context_window_fallbacks = (
             context_window_fallbacks or litellm.context_window_fallbacks
         )
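In other words, `default_fallbacks` is folded into the regular fallbacks list under the wildcard key `"*"`. Illustrative only, with made-up values:

```python
# Router(fallbacks=[{"gpt-3.5-turbo": ["gpt-4"]}], default_fallbacks=["my-good-model"])
# leaves the instance holding both entries, the generic one keyed by "*":
fallbacks = [
    {"gpt-3.5-turbo": ["gpt-4"]},  # model-group-specific fallback
    {"*": ["my-good-model"]},  # appended from default_fallbacks
]
```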
@@ -1471,13 +1480,21 @@ class Router:
                     pass
             elif fallbacks is not None:
                 verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
-                for item in fallbacks:
-                    key_list = list(item.keys())
-                    if len(key_list) == 0:
-                        continue
-                    if key_list[0] == model_group:
+                generic_fallback_idx: Optional[int] = None
+                ## check for specific model group-specific fallbacks
+                for idx, item in enumerate(fallbacks):
+                    if list(item.keys())[0] == model_group:
                         fallback_model_group = item[model_group]
                         break
+                    elif list(item.keys())[0] == "*":
+                        generic_fallback_idx = idx
+                ## if none, check for generic fallback
+                if (
+                    fallback_model_group is None
+                    and generic_fallback_idx is not None
+                ):
+                    fallback_model_group = fallbacks[generic_fallback_idx]["*"]

                 if fallback_model_group is None:
                     verbose_router_logger.info(
                         f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
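The selection order here is: an entry keyed by the failing model group wins, otherwise the `"*"` entry populated from `default_fallbacks` is used. A standalone sketch of that lookup, with a hypothetical helper name that is not part of litellm:

```python
from typing import Dict, List, Optional


def pick_fallback_group(
    fallbacks: List[Dict[str, List[str]]], model_group: str
) -> Optional[List[str]]:
    """Hypothetical helper mirroring the order above: a model-group-specific
    entry wins; otherwise the generic "*" entry (if any) is returned."""
    generic: Optional[List[str]] = None
    for item in fallbacks:
        key = list(item.keys())[0]
        if key == model_group:
            return item[model_group]
        elif key == "*":
            generic = item["*"]
    return generic


# No entry for "bad-model", so the wildcard group is chosen.
assert pick_fallback_group(
    [{"gpt-3.5-turbo": ["gpt-4"]}, {"*": ["my-good-model"]}], "bad-model"
) == ["my-good-model"]
```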
@@ -1537,7 +1554,7 @@ class Router:

         """
         _healthy_deployments = await self._async_get_healthy_deployments(
-            model=kwargs.get("model"),
+            model=kwargs.get("model") or "",
         )

         # raises an exception if this error should not be retries
@@ -1644,12 +1661,18 @@ class Router:
         Try calling the function_with_retries
         If it fails after num_retries, fall back to another model group
         """
+        mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
         model_group = kwargs.get("model")
         fallbacks = kwargs.get("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.get(
             "context_window_fallbacks", self.context_window_fallbacks
         )
         try:
+            if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
+                raise Exception(
+                    f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
+                )
+
             response = self.function_with_retries(*args, **kwargs)
             return response
         except Exception as e:
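The `mock_testing_fallbacks` flag popped above forces this fallback path without needing a real provider failure; the new test added below relies on it. A sketch, reusing the `router` from the earlier sketch:

```python
# Forces the primary call to raise, then lets the fallback deployment answer
# via mock_response, so no network call is made.
response = router.completion(
    model="bad-model",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    mock_testing_fallbacks=True,
    mock_response="Hey! nice day",
)
print(response.model)  # expected: the fallback deployment's model, e.g. "gpt-4o"
```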
@@ -1658,7 +1681,7 @@ class Router:
             try:
                 if (
                     hasattr(e, "status_code")
-                    and e.status_code == 400
+                    and e.status_code == 400 # type: ignore
                     and not isinstance(e, litellm.ContextWindowExceededError)
                 ):  # don't retry a malformed request
                     raise e
@@ -1700,10 +1723,20 @@ class Router:
             elif fallbacks is not None:
                 verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
                 fallback_model_group = None
-                for item in fallbacks:
+                generic_fallback_idx: Optional[int] = None
+                ## check for specific model group-specific fallbacks
+                for idx, item in enumerate(fallbacks):
                     if list(item.keys())[0] == model_group:
                         fallback_model_group = item[model_group]
                         break
+                    elif list(item.keys())[0] == "*":
+                        generic_fallback_idx = idx
+                ## if none, check for generic fallback
+                if (
+                    fallback_model_group is None
+                    and generic_fallback_idx is not None
+                ):
+                    fallback_model_group = fallbacks[generic_fallback_idx]["*"]

                 if fallback_model_group is None:
                     raise original_exception
@@ -26,7 +26,7 @@ model_list = [
     }
 ]

-router = litellm.Router(model_list=model_list)
+router = litellm.Router(model_list=model_list) # type: ignore


 async def _openai_completion():
@@ -3342,6 +3342,8 @@ def test_completion_watsonx():
         print(response)
     except litellm.APIError as e:
         pass
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -3361,6 +3363,8 @@ def test_completion_stream_watsonx():
             print(chunk)
     except litellm.APIError as e:
         pass
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -3425,6 +3429,8 @@ async def test_acompletion_watsonx():
         )
         # Add any assertions here to check the response
         print(response)
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -3445,6 +3451,8 @@ async def test_acompletion_stream_watsonx():
         # Add any assertions here to check the response
         async for chunk in response:
             print(chunk)
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -494,6 +494,8 @@ def test_watsonx_embeddings():
         )
         print(f"response: {response}")
         assert isinstance(response.usage, litellm.Usage)
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -1007,3 +1007,50 @@ async def test_service_unavailable_fallbacks(sync_mode):
     )

     assert response.model == "gpt-35-turbo"
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_default_model_fallbacks(sync_mode):
+    """
+    Related issue - https://github.com/BerriAI/litellm/issues/3623
+
+    If model misconfigured, setup a default model for generic fallback
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "bad-model",
+                "litellm_params": {
+                    "model": "openai/my-bad-model",
+                    "api_key": "my-bad-api-key",
+                },
+            },
+            {
+                "model_name": "my-good-model",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ],
+        default_fallbacks=["my-good-model"],
+    )
+
+    if sync_mode:
+        response = router.completion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+    else:
+        response = await router.acompletion(
+            model="bad-model",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            mock_testing_fallbacks=True,
+            mock_response="Hey! nice day",
+        )
+
+    assert isinstance(response, litellm.ModelResponse)
+    assert response.model is not None and response.model == "gpt-4o"
@@ -457,6 +457,7 @@ def test_completion_claude_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


 # test_completion_claude_stream()
 def test_completion_claude_2_stream():
     litellm.set_verbose = True
@@ -1416,6 +1417,8 @@ def test_completion_watsonx_stream():
             raise Exception("finish reason not set for last chunk")
         if complete_response.strip() == "":
             raise Exception("Empty response received")
+    except litellm.RateLimitError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union, Dict, Tuple, Literal
+from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
 import httpx
 from pydantic import BaseModel, validator, Field
 from .completion import CompletionRequest
@@ -277,6 +277,47 @@ class updateDeployment(BaseModel):
         protected_namespaces = ()


+class LiteLLMParamsTypedDict(TypedDict, total=False):
+    """
+    [TODO]
+    - allow additional params (not in list)
+    - set value to none if not set -> don't raise error if value not set
+    """
+
+    model: str
+    custom_llm_provider: Optional[str]
+    tpm: Optional[int]
+    rpm: Optional[int]
+    api_key: Optional[str]
+    api_base: Optional[str]
+    api_version: Optional[str]
+    timeout: Optional[Union[float, str, httpx.Timeout]]
+    stream_timeout: Optional[Union[float, str]]
+    max_retries: Optional[int]
+    organization: Optional[str]  # for openai orgs
+    ## UNIFIED PROJECT/REGION ##
+    region_name: Optional[str]
+    ## VERTEX AI ##
+    vertex_project: Optional[str]
+    vertex_location: Optional[str]
+    ## AWS BEDROCK / SAGEMAKER ##
+    aws_access_key_id: Optional[str]
+    aws_secret_access_key: Optional[str]
+    aws_region_name: Optional[str]
+    ## IBM WATSONX ##
+    watsonx_region_name: Optional[str]
+    ## CUSTOM PRICING ##
+    input_cost_per_token: Optional[float]
+    output_cost_per_token: Optional[float]
+    input_cost_per_second: Optional[float]
+    output_cost_per_second: Optional[float]
+
+
+class DeploymentTypedDict(TypedDict):
+    model_name: str
+    litellm_params: LiteLLMParamsTypedDict
+
+
 class Deployment(BaseModel):
     model_name: str
     litellm_params: LiteLLM_Params
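With these TypedDicts exported from `litellm.types.router`, a plain-dict `model_list` can be checked by a type checker before it is handed to `Router(model_list=...)`. A minimal sketch (values are placeholders):

```python
from typing import List

from litellm.types.router import DeploymentTypedDict

model_list: List[DeploymentTypedDict] = [
    {
        "model_name": "my-good-model",
        "litellm_params": {
            "model": "gpt-4o",
            "api_key": "sk-...",  # placeholder
            "rpm": 100,
        },
    },
]
```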
@@ -13,6 +13,7 @@ import dotenv, json, traceback, threading, base64, ast
 import subprocess, os
 from os.path import abspath, join, dirname
 import litellm, openai

 import itertools
 import random, uuid, requests # type: ignore
 from functools import wraps
@@ -8514,7 +8515,7 @@ def exception_type(
                 request=original_exception.request,
             )
         elif custom_llm_provider == "watsonx":
-            if "token_quota_reached" in error_response:
+            if "token_quota_reached" in error_str:
                 exception_mapping_worked = True
                 raise RateLimitError(
                     message=f"WatsonxException: Rate Limit Errror - {error_str}",
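Since `token_quota_reached` is now matched against `error_str`, watsonx quota exhaustion surfaces as `litellm.RateLimitError`, which is exactly what the updated watsonx tests above tolerate. A sketch of the caller-side pattern (the model id is an assumed example):

```python
import litellm

try:
    litellm.completion(
        model="watsonx/ibm/granite-13b-chat-v2",  # assumed example model id
        messages=[{"role": "user", "content": "Hello"}],
    )
except litellm.RateLimitError:
    # watsonx "token_quota_reached" errors are mapped here
    pass
```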