fix(router.py): cooldown deployments, for 401 errors
parent: 8ee51a96f4
commit: 1baad80c7d
6 changed files with 165 additions and 14 deletions
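For context, the behavior this commit targets can be summed up in a standalone sketch (not the router's actual code): a deployment that fails with a non-retryable status such as 401 is cooled down immediately, while retryable statuses (e.g. 429/5xx) still have to exceed allowed_fails first. The real check lives in litellm._should_retry(); the RETRYABLE_STATUS_CODES set below is an assumption for illustration only.

    # Hypothetical sketch of the new cooldown decision - illustration only, not the diff itself.
    RETRYABLE_STATUS_CODES = {408, 409, 429, 500, 502, 503, 504}  # assumed set; the real logic is litellm._should_retry()

    def should_cooldown(exception_status: int, updated_fails: int, allowed_fails: int) -> bool:
        """Cool a deployment down if it exceeded allowed fails OR the error isn't worth retrying (e.g. 401)."""
        _should_retry = exception_status in RETRYABLE_STATUS_CODES
        return updated_fails > allowed_fails or not _should_retry

    assert should_cooldown(exception_status=401, updated_fails=1, allowed_fails=3) is True   # auth error -> immediate cooldown
    assert should_cooldown(exception_status=500, updated_fails=1, allowed_fails=3) is False  # retryable -> stays in rotation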
@@ -387,6 +387,19 @@ def mock_completion(
    - If 'stream' is True, it returns a response that mimics the behavior of a streaming completion.
    """
    try:
+        ## LOGGING
+        logging.pre_call(
+            input=messages,
+            api_key="mock-key",
+        )
+        if isinstance(mock_response, Exception):
+            raise litellm.APIError(
+                status_code=500, # type: ignore
+                message=str(mock_response),
+                llm_provider="openai", # type: ignore
+                model=model, # type: ignore
+                request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
+            )
        model_response = ModelResponse(stream=stream)
        if stream is True:
            # don't try to access stream object,
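The mock_completion() change above is what the new retry tests lean on: passing an Exception as mock_response now raises a litellm.APIError (status 500) instead of returning a canned response. A minimal usage sketch, assuming mock_response is forwarded from litellm.completion() the same way the router tests below pass it:

    import litellm

    # Simulate a provider failure without hitting a real API.
    try:
        litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hi"}],
            mock_response=Exception("Invalid Request"),  # triggers the new isinstance(mock_response, Exception) branch
        )
    except litellm.APIError as e:
        print(e.status_code)  # 500, which litellm treats as retryable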
@@ -1418,13 +1418,6 @@ class Router:
            traceback.print_exc()
            raise original_exception

-    async def _async_router_should_retry(
-        self, e: Exception, remaining_retries: int, num_retries: int
-    ):
-        """
-        Calculate back-off, then retry
-        """
-
    async def async_function_with_retries(self, *args, **kwargs):
        verbose_router_logger.debug(
            f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
@@ -1674,6 +1667,7 @@ class Router:
        context_window_fallbacks = kwargs.pop(
            "context_window_fallbacks", self.context_window_fallbacks
        )

        try:
            # if the function call is successful, no exception will be raised and we'll break out of the loop
            response = original_function(*args, **kwargs)
@@ -1751,10 +1745,11 @@ class Router:
            ) # i.e. azure
            metadata = kwargs.get("litellm_params", {}).get("metadata", None)
            _model_info = kwargs.get("litellm_params", {}).get("model_info", {})

            if isinstance(_model_info, dict):
                deployment_id = _model_info.get("id", None)
                self._set_cooldown_deployments(
-                    deployment_id
+                    exception_status=exception_status, deployment=deployment_id
                ) # setting deployment_id in cooldown deployments
            if custom_llm_provider:
                model_name = f"{custom_llm_provider}/{model_name}"
@@ -1814,9 +1809,15 @@ class Router:
            key=rpm_key, value=request_count, local_only=True
        ) # don't change existing ttl

-    def _set_cooldown_deployments(self, deployment: Optional[str] = None):
+    def _set_cooldown_deployments(
+        self, exception_status: Union[str, int], deployment: Optional[str] = None
+    ):
        """
        Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
+
+        or
+
+        the exception is not one that should be immediately retried (e.g. 401)
        """
        if deployment is None:
            return
@@ -1833,7 +1834,20 @@ class Router:
            f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
        )
        cooldown_time = self.cooldown_time or 1
-        if updated_fails > self.allowed_fails:
+
+        if isinstance(exception_status, str):
+            try:
+                exception_status = int(exception_status)
+            except Exception as e:
+                verbose_router_logger.debug(
+                    "Unable to cast exception status to int {}. Defaulting to status=500.".format(
+                        exception_status
+                    )
+                )
+                exception_status = 500
+        _should_retry = litellm._should_retry(status_code=exception_status)
+
+        if updated_fails > self.allowed_fails or _should_retry == False:
            # get the current cooldown list for that minute
            cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
            cached_value = self.cache.get_cache(key=cooldown_key)
@@ -19,6 +19,7 @@ def setup_and_teardown():
        0, os.path.abspath("../..")
    ) # Adds the project directory to the system path
    import litellm
+    from litellm import Router

    importlib.reload(litellm)
    import asyncio
@@ -1154,6 +1154,7 @@ def test_consistent_model_id():
    assert id1 == id2


+@pytest.mark.skip(reason="local test")
def test_reading_keys_os_environ():
    import openai
@@ -1253,6 +1254,7 @@ def test_reading_keys_os_environ():
# test_reading_keys_os_environ()


+@pytest.mark.skip(reason="local test")
def test_reading_openai_keys_os_environ():
    import openai
@@ -22,10 +22,10 @@ class MyCustomHandler(CustomLogger):
    def log_pre_api_call(self, model, messages, kwargs):
        print(f"Pre-API Call")
        print(
-            f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
+            f"previous_models: {kwargs['litellm_params']['metadata'].get('previous_models', None)}"
        )
-        self.previous_models += len(
-            kwargs["litellm_params"]["metadata"]["previous_models"]
+        self.previous_models = len(
+            kwargs["litellm_params"]["metadata"].get("previous_models", [])
        ) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
        print(f"self.previous_models: {self.previous_models}")
@@ -140,7 +140,7 @@ def test_sync_fallbacks():

@pytest.mark.asyncio
async def test_async_fallbacks():
-    litellm.set_verbose = False
+    litellm.set_verbose = True
    model_list = [
        { # list of model deployments
            "model_name": "azure/gpt-3.5-turbo", # openai model name
litellm/tests/test_router_retries.py (new file, 121 lines)
@@ -0,0 +1,121 @@
#### What this tests ####
# This tests calling router with fallback models

import sys, os, time
import traceback, asyncio
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
) # Adds the parent directory to the system path

import litellm
from litellm import Router
from litellm.integrations.custom_logger import CustomLogger


class MyCustomHandler(CustomLogger):
    success: bool = False
    failure: bool = False
    previous_models: int = 0

    def log_pre_api_call(self, model, messages, kwargs):
        print(f"Pre-API Call")
        print(
            f"previous_models: {kwargs['litellm_params']['metadata'].get('previous_models', None)}"
        )
        self.previous_models = len(
            kwargs["litellm_params"]["metadata"].get("previous_models", [])
        ) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
        print(f"self.previous_models: {self.previous_models}")

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        print(
            f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}"
        )

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Stream")

    def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Stream")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Success")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Success")

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Failure")


"""
Test sync + async

- Authorization Errors
- Random API Error
"""


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize("error_type", ["Authorization Error", "API Error"])
@pytest.mark.asyncio
async def test_router_retries_errors(sync_mode, error_type):
    """
    - Auth Error -> 0 retries
    - API Error -> 2 retries
    """

    _api_key = (
        "bad-key" if error_type == "Authorization Error" else os.getenv("AZURE_API_KEY")
    )
    print(f"_api_key: {_api_key}")
    model_list = [
        {
            "model_name": "azure/gpt-3.5-turbo", # openai model name
            "litellm_params": { # params for litellm completion/embedding call
                "model": "azure/chatgpt-functioncalling",
                "api_key": _api_key,
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
            "tpm": 240000,
            "rpm": 1800,
        },
    ]

    router = Router(model_list=model_list, allowed_fails=3)

    customHandler = MyCustomHandler()
    litellm.callbacks = [customHandler]
    user_message = "Hello, how are you?"
    messages = [{"content": user_message, "role": "user"}]

    kwargs = {
        "model": "azure/gpt-3.5-turbo",
        "messages": messages,
        "mock_response": (
            None
            if error_type == "Authorization Error"
            else Exception("Invalid Request")
        ),
    }

    try:
        if sync_mode:
            response = router.completion(**kwargs)
        else:
            response = await router.acompletion(**kwargs)
    except Exception as e:
        pass

    await asyncio.sleep(
        0.05
    ) # allow a delay as success_callbacks are on a separate thread
    print(f"customHandler.previous_models: {customHandler.previous_models}")

    if error_type == "Authorization Error":
        assert customHandler.previous_models == 0 # 0 retries
    else:
        assert customHandler.previous_models == 2 # 2 retries
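To exercise the new retry behavior locally, the parametrized test can be run on its own. This assumes the AZURE_API_KEY / AZURE_API_BASE / AZURE_API_VERSION environment variables referenced in the model list are set (the "Authorization Error" cases deliberately substitute a bad key):

    pytest litellm/tests/test_router_retries.py -k test_router_retries_errors -s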