(testing) Router add testing coverage (#6253)

* test: add more router code coverage

* test: additional router testing coverage

* fix: fix linting error

* test: fix tests for ci/cd

* test: fix test

* test: handle flaky tests

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
Ishaan Jaff 2024-10-16 20:02:27 +05:30 committed by GitHub
parent 54ebdbf7ce
commit 8530000b44
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 706 additions and 106 deletions

View file

@@ -861,11 +861,23 @@ class Router:
self.fail_calls[model_name] += 1
raise e
+ def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
+ """
+ Adds default litellm params to kwargs, if set.
+ """
+ for k, v in self.default_litellm_params.items():
+ if (
+ k not in kwargs and v is not None
+ ): # prioritize model-specific params > default router params
+ kwargs[k] = v
+ elif k == "metadata":
+ kwargs[k].update(v)
def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None:
"""
- Adds selected deployment, model_info and api_base to kwargs["metadata"]
- This is used in litellm logging callbacks
+ 2 jobs:
+ - Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)
+ - Adds default litellm params to kwargs, if set.
"""
kwargs.setdefault("metadata", {}).update(
{

@@ -875,13 +887,7 @@ class Router:
}
)
kwargs["model_info"] = deployment.get("model_info", {})
- for k, v in self.default_litellm_params.items():
- if (
- k not in kwargs and v is not None
- ): # prioritize model-specific params > default router params
- kwargs[k] = v
- elif k == "metadata":
- kwargs[k].update(v)
+ self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
"""

@@ -910,6 +916,7 @@ class Router:
return model_client
def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
+ """Helper to get timeout from kwargs or deployment params"""
timeout = (
data.get(
"timeout", None
@@ -3414,11 +3421,10 @@ class Router:
):
"""
Track remaining tpm/rpm quota for model in model_list
+ Currently, only updates TPM usage.
"""
try:
- """
- Update TPM usage on success
- """
if kwargs["litellm_params"].get("metadata") is None:
pass
else:

@@ -3459,6 +3465,8 @@ class Router:
deployment_id=id,
)
+ return tpm_key
except Exception as e:
verbose_router_logger.exception(
"litellm.router.Router::deployment_callback_on_success(): Exception occured - {}".format(

@@ -3473,7 +3481,14 @@ class Router:
completion_response, # response from completion
start_time,
end_time, # start/end time
- ):
+ ) -> Optional[str]:
+ """
+ Tracks the number of successes for a deployment in the current minute (using in-memory cache)
+ Returns:
+ - key: str - The key used to increment the cache
+ - None: if no key is found
+ """
id = None
if kwargs["litellm_params"].get("metadata") is None:
pass

@@ -3482,15 +3497,18 @@ class Router:
model_info = kwargs["litellm_params"].get("model_info", {}) or {}
id = model_info.get("id", None)
if model_group is None or id is None:
- return
+ return None
elif isinstance(id, int):
id = str(id)
if id is not None:
- increment_deployment_successes_for_current_minute(
+ key = increment_deployment_successes_for_current_minute(
litellm_router_instance=self,
deployment_id=id,
)
+ return key
+ return None
def deployment_callback_on_failure(
self,
@@ -3498,15 +3516,19 @@ class Router:
completion_response, # response from completion
start_time,
end_time, # start/end time
- ):
+ ) -> bool:
+ """
+ 2 jobs:
+ - Tracks the number of failures for a deployment in the current minute (using in-memory cache)
+ - Puts the deployment in cooldown if it exceeds the allowed fails / minute
+ Returns:
+ - True if the deployment should be put in cooldown
+ - False if the deployment should not be put in cooldown
+ """
try:
exception = kwargs.get("exception", None)
exception_status = getattr(exception, "status_code", "")
- model_name = kwargs.get("model", None) # i.e. gpt35turbo
- custom_llm_provider = kwargs.get("litellm_params", {}).get(
- "custom_llm_provider", None
- ) # i.e. azure
- kwargs.get("litellm_params", {}).get("metadata", None)
_model_info = kwargs.get("litellm_params", {}).get("model_info", {})
exception_headers = litellm.litellm_core_utils.exception_mapping_utils._get_response_headers(

@@ -3535,15 +3557,17 @@ class Router:
litellm_router_instance=self,
deployment_id=deployment_id,
)
- _set_cooldown_deployments(
+ result = _set_cooldown_deployments(
litellm_router_instance=self,
exception_status=exception_status,
original_exception=exception,
deployment=deployment_id,
time_to_cooldown=_time_to_cooldown,
) # setting deployment_id in cooldown deployments
- if custom_llm_provider:
- model_name = f"{custom_llm_provider}/{model_name}"
+ return result
+ else:
+ return False
except Exception as e:
raise e
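A simplified illustration of the "allowed fails per minute" policy that the boolean return of deployment_callback_on_failure / _set_cooldown_deployments reports. This is a toy sketch only: litellm's real implementation lives in router_utils/cooldown_handlers.py, also accounts for the exception status, and uses the router's in-memory cache rather than a plain Counter.

from collections import Counter

fails_this_minute: Counter = Counter()
ALLOWED_FAILS_PER_MINUTE = 3  # assumption for the sketch; litellm reads this from router config

def should_cooldown(deployment_id: str) -> bool:
    # Record one failure and report whether the deployment crossed the threshold.
    fails_this_minute[deployment_id] += 1
    return fails_this_minute[deployment_id] > ALLOWED_FAILS_PER_MINUTE

assert should_cooldown("deployment-1") is False  # first failure in the minute
for _ in range(3):
    last = should_cooldown("deployment-1")
assert last is True  # fourth failure exceeds the allowed fails -> cooldown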
@@ -3582,9 +3606,12 @@ class Router:
except Exception as e:
raise e
- def _update_usage(self, deployment_id: str):
+ def _update_usage(self, deployment_id: str) -> int:
"""
Update deployment rpm for that minute
+ Returns:
+ - int: request count
"""
rpm_key = deployment_id

@@ -3600,6 +3627,8 @@ class Router:
key=rpm_key, value=request_count, local_only=True
) # don't change existing ttl
+ return request_count
def _is_cooldown_required(
self,
model_id: str,

@@ -3778,7 +3807,7 @@ class Router:
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
try:
- _ = await _callback.async_pre_call_check(deployment)
+ await _callback.async_pre_call_check(deployment)
except litellm.RateLimitError as e:
## LOG FAILURE EVENT
if logging_obj is not None:

@@ -3848,10 +3877,23 @@ class Router:
return hash_object.hexdigest()
def _create_deployment(
- self, model: dict, _model_name: str, _litellm_params: dict, _model_info: dict
- ):
+ self,
+ deployment_info: dict,
+ _model_name: str,
+ _litellm_params: dict,
+ _model_info: dict,
+ ) -> Optional[Deployment]:
+ """
+ Create a deployment object and add it to the model list
+ If the deployment is not active for the current environment, it is ignored
+ Returns:
+ - Deployment: The deployment object
+ - None: If the deployment is not active for the current environment (if 'supported_environments' is set in litellm_params)
+ """
deployment = Deployment(
- **model,
+ **deployment_info,
model_name=_model_name,
litellm_params=LiteLLM_Params(**_litellm_params),
model_info=_model_info,
@@ -3870,18 +3912,18 @@ class Router:
)
## Check if LLM Deployment is allowed for this deployment
- if deployment.model_info and "supported_environments" in deployment.model_info:
- if (
- self.deployment_is_active_for_environment(deployment=deployment)
- is not True
- ):
- return
+ if self.deployment_is_active_for_environment(deployment=deployment) is not True:
+ verbose_router_logger.warning(
+ f"Ignoring deployment {deployment.model_name} as it is not active for environment {deployment.model_info['supported_environments']}"
+ )
+ return None
deployment = self._add_deployment(deployment=deployment)
model = deployment.to_json(exclude_none=True)
self.model_list.append(model)
+ return deployment
def deployment_is_active_for_environment(self, deployment: Deployment) -> bool:
"""

@@ -3896,6 +3938,12 @@ class Router:
- ValueError: If LITELLM_ENVIRONMENT is not set in .env or not one of the valid values
- ValueError: If supported_environments is not set in model_info or not one of the valid values
"""
+ if (
+ deployment.model_info is None
+ or "supported_environments" not in deployment.model_info
+ or deployment.model_info["supported_environments"] is None
+ ):
+ return True
litellm_environment = get_secret_str(secret_name="LITELLM_ENVIRONMENT")
if litellm_environment is None:
raise ValueError(

@@ -3913,7 +3961,6 @@ class Router:
f"supported_environments must be one of {VALID_LITELLM_ENVIRONMENTS}. but set as: {_env} for deployment: {deployment}"
)
- # validate litellm_environment is one of LiteLLMEnvironment
if litellm_environment in deployment.model_info["supported_environments"]:
return True
return False
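A condensed sketch of the environment gate that deployment_is_active_for_environment applies, including the new early return when no supported_environments is declared. This re-implements the check in isolation for illustration: litellm reads LITELLM_ENVIRONMENT via get_secret_str and validates values against VALID_LITELLM_ENVIRONMENTS, which this sketch skips. The "staging"/"development" values mirror the tests added later in this diff.

import os
from typing import List, Optional

def is_active_for_environment(supported_environments: Optional[List[str]]) -> bool:
    if not supported_environments:
        # No restriction declared -> deployment is active everywhere (new early return in this commit).
        return True
    litellm_environment = os.getenv("LITELLM_ENVIRONMENT")
    if litellm_environment is None:
        raise ValueError("LITELLM_ENVIRONMENT must be set when supported_environments is used")
    return litellm_environment in supported_environments

os.environ["LITELLM_ENVIRONMENT"] = "staging"
assert is_active_for_environment(None) is True
assert is_active_for_environment(["staging"]) is True
assert is_active_for_environment(["development"]) is False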
@@ -3946,14 +3993,14 @@ class Router:
for org in _litellm_params["organization"]:
_litellm_params["organization"] = org
self._create_deployment(
- model=model,
+ deployment_info=model,
_model_name=_model_name,
_litellm_params=_litellm_params,
_model_info=_model_info,
)
else:
self._create_deployment(
- model=model,
+ deployment_info=model,
_model_name=_model_name,
_litellm_params=_litellm_params,
_model_info=_model_info,

@@ -4118,9 +4165,9 @@ class Router:
if removal_idx is not None:
self.model_list.pop(removal_idx)
- else:
# if the model_id is not in router
self.add_deployment(deployment=deployment)
return deployment
def delete_deployment(self, id: str) -> Optional[Deployment]:

@@ -4628,16 +4675,13 @@ class Router:
from collections import defaultdict
access_groups = defaultdict(list)
- if self.access_groups:
- return self.access_groups
if self.model_list:
for m in self.model_list:
for group in m.get("model_info", {}).get("access_groups", []):
model_name = m["model_name"]
access_groups[group].append(model_name)
- # set access groups
- self.access_groups = access_groups
return access_groups
def get_settings(self):
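With the caching removed above, get_model_access_groups rebuilds the group-to-model mapping on every call. A small sketch of the mapping it returns; the model names and group names below are illustrative and match the test fixture added later in this diff.

from collections import defaultdict

model_list = [
    {"model_name": "gpt-3.5-turbo", "model_info": {"access_groups": ["group1", "group2"]}},
    {"model_name": "gpt-4o", "model_info": {}},
]

access_groups = defaultdict(list)
for m in model_list:
    for group in m.get("model_info", {}).get("access_groups", []):
        access_groups[group].append(m["model_name"])

assert dict(access_groups) == {"group1": ["gpt-3.5-turbo"], "group2": ["gpt-3.5-turbo"]}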
@@ -4672,6 +4716,9 @@ class Router:
return _settings_to_return
def update_settings(self, **kwargs):
+ """
+ Update the router settings.
+ """
# only the following settings are allowed to be configured
_allowed_settings = [
"routing_strategy_args",
@@ -5367,66 +5414,16 @@ class Router:
return healthy_deployments
def _track_deployment_metrics(self, deployment, response=None):
+ """
+ Tracks successful requests rpm usage.
+ """
try:
- litellm_params = deployment["litellm_params"]
- api_base = litellm_params.get("api_base", "")
- model = litellm_params.get("model", "")
model_id = deployment.get("model_info", {}).get("id", None)
if response is None:
# update self.deployment_stats
if model_id is not None:
self._update_usage(model_id) # update in-memory cache for tracking
- if model_id in self.deployment_stats:
- # only update num_requests
- self.deployment_stats[model_id]["num_requests"] += 1
- else:
- self.deployment_stats[model_id] = {
- "api_base": api_base,
- "model": model,
- "num_requests": 1,
- }
- else:
- # check response_ms and update num_successes
- if isinstance(response, dict):
- response_ms = response.get("_response_ms", 0)
- else:
- response_ms = 0
- if model_id is not None:
- if model_id in self.deployment_stats:
- # check if avg_latency exists
- if "avg_latency" in self.deployment_stats[model_id]:
- # update avg_latency
- self.deployment_stats[model_id]["avg_latency"] = (
- self.deployment_stats[model_id]["avg_latency"]
- + response_ms
- ) / self.deployment_stats[model_id]["num_successes"]
- else:
- self.deployment_stats[model_id]["avg_latency"] = response_ms
- # check if num_successes exists
- if "num_successes" in self.deployment_stats[model_id]:
- self.deployment_stats[model_id]["num_successes"] += 1
- else:
- self.deployment_stats[model_id]["num_successes"] = 1
- else:
- self.deployment_stats[model_id] = {
- "api_base": api_base,
- "model": model,
- "num_successes": 1,
- "avg_latency": response_ms,
- }
- if self.set_verbose is True and self.debug_level == "DEBUG":
- from pprint import pformat
- # Assuming self.deployment_stats is your dictionary
- formatted_stats = pformat(self.deployment_stats)
- # Assuming verbose_router_logger is your logger
- verbose_router_logger.info(
- "self.deployment_stats: \n%s", formatted_stats
- )
except Exception as e:
verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")

@@ -5442,6 +5439,7 @@ class Router:
"""
# if we can find the exception then in the retry policy -> return the number of retries
+ retry_policy: Optional[RetryPolicy] = self.retry_policy
if (
self.model_group_retry_policy is not None
and model_group is not None
@@ -5540,7 +5538,9 @@ class Router:
litellm.success_callback.append(
_slack_alerting_logger.response_taking_too_long_callback
)
- print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n") # noqa
+ verbose_router_logger.info(
+ "\033[94m\nInitialized Alerting for litellm.Router\033[0m\n"
+ )
def set_custom_routing_strategy(
self, CustomRoutingStrategy: CustomRoutingStrategyBase

View file

@@ -148,13 +148,17 @@ def _set_cooldown_deployments(
exception_status: Union[str, int],
deployment: Optional[str] = None,
time_to_cooldown: Optional[float] = None,
- ):
+ ) -> bool:
"""
Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
or
the exception is not one that should be immediately retried (e.g. 401)
+ Returns:
+ - True if the deployment should be put in cooldown
+ - False if the deployment should not be put in cooldown
"""
if (
_should_run_cooldown_logic(

@@ -163,7 +167,7 @@ def _set_cooldown_deployments(
is False
or deployment is None
):
- return
+ return False
exception_status_int = cast_exception_status_to_int(exception_status)

@@ -191,6 +195,8 @@ def _set_cooldown_deployments(
cooldown_time=cooldown_time,
)
)
+ return True
+ return False
async def _async_get_cooldown_deployments(

View file

@@ -24,7 +24,7 @@ else:
def increment_deployment_successes_for_current_minute(
litellm_router_instance: LitellmRouter,
deployment_id: str,
- ):
+ ) -> str:
"""
In-Memory: Increments the number of successes for the current minute for a deployment_id
"""

@@ -35,6 +35,7 @@ def increment_deployment_successes_for_current_minute(
value=1,
ttl=60,
)
+ return key
def increment_deployment_failures_for_current_minute(
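A standalone sketch of the per-minute counter pattern these helpers implement: increment a key in an in-memory store with a 60-second TTL and hand the key back so callers (and the new tests) can inspect it. The key format and the TTLCounter class below are illustrative assumptions, not litellm's actual cache or key layout; litellm uses the router's in-memory cache instead.

import time
from typing import Dict, Tuple

class TTLCounter:
    """Toy in-memory counter with a per-entry TTL (illustrative only)."""

    def __init__(self) -> None:
        self._store: Dict[str, Tuple[int, float]] = {}  # key -> (value, expires_at)

    def increment(self, key: str, value: int = 1, ttl: int = 60) -> int:
        current, expires_at = self._store.get(key, (0, 0.0))
        if time.time() > expires_at:
            current, expires_at = 0, time.time() + ttl  # entry expired, start a new window
        current += value
        self._store[key] = (current, expires_at)
        return current

def increment_deployment_successes_for_current_minute(cache: TTLCounter, deployment_id: str) -> str:
    current_minute = time.strftime("%H-%M")  # illustrative key format
    key = f"{deployment_id}:successes:{current_minute}"
    cache.increment(key, value=1, ttl=60)
    return key  # returned so callers/tests can look up the counter

cache = TTLCounter()
key = increment_deployment_successes_for_current_minute(cache, "deployment-1")
assert key is not None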

View file

@@ -11,9 +11,15 @@ def get_function_names_from_file(file_path):
function_names = []
- for node in ast.walk(tree):
+ for node in tree.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ # Top-level functions
function_names.append(node.name)
+ elif isinstance(node, ast.ClassDef):
+ # Functions inside classes
+ for class_node in node.body:
+ if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+ function_names.append(class_node.name)
return function_names

@@ -79,6 +85,7 @@ ignored_function_names = [
"a_add_message",
"aget_messages",
"arun_thread",
+ "try_retrieve_batch",
]

@@ -103,8 +110,8 @@ def main():
if func not in ignored_function_names:
all_untested_functions.append(func)
untested_perc = (len(all_untested_functions)) / len(router_functions)
- print("perc_covered: ", untested_perc)
+ print("untested_perc: ", untested_perc)
- if untested_perc < 0.3:
+ if untested_perc > 0:
print("The following functions in router.py are not tested:")
raise Exception(
f"{untested_perc * 100:.2f}% of functions in router.py are not tested: {all_untested_functions}"

View file

@@ -20,6 +20,7 @@ import boto3
@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
+ @pytest.mark.flaky(retries=6, delay=1)
async def test_basic_s3_logging(sync_mode):
verbose_logger.setLevel(level=logging.DEBUG)
litellm.success_callback = ["s3"]
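The flaky marker added here reruns the test before reporting a failure. A minimal usage sketch, assuming a pytest retry plugin that provides a flaky marker accepting retries and delay keyword arguments (as the decorator in the diff above implies this repo's CI has installed); the test body is illustrative.

import random
import pytest

# Retry a nondeterministic test up to 6 times, waiting 1 second between attempts.
@pytest.mark.flaky(retries=6, delay=1)
def test_occasionally_fails():
    assert random.random() < 0.9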

View file

@@ -3789,6 +3789,7 @@ def test_completion_anyscale_api():
# @pytest.mark.skip(reason="flaky test, times out frequently")
+ @pytest.mark.flaky(retries=6, delay=1)
def test_completion_cohere():
try:
# litellm.set_verbose=True

View file

@@ -10,6 +10,7 @@ sys.path.insert(
) # Adds the parent directory to the system path
from litellm import Router
import pytest
+ import litellm
from unittest.mock import patch, MagicMock, AsyncMock

@@ -22,6 +23,9 @@ def model_list():
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
+ "model_info": {
+ "access_groups": ["group1", "group2"],
+ },
},
{
"model_name": "gpt-4o",
@@ -250,3 +254,583 @@ async def test_router_make_call(model_list):
mock_response="https://example.com/image.png",
)
assert response.data[0].url == "https://example.com/image.png"
def test_update_kwargs_with_deployment(model_list):
"""Test if the '_update_kwargs_with_deployment' function is working correctly"""
router = Router(model_list=model_list)
kwargs: dict = {"metadata": {}}
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
router._update_kwargs_with_deployment(
deployment=deployment,
kwargs=kwargs,
)
set_fields = ["deployment", "api_base", "model_info"]
assert all(field in kwargs["metadata"] for field in set_fields)
def test_update_kwargs_with_default_litellm_params(model_list):
"""Test if the '_update_kwargs_with_default_litellm_params' function is working correctly"""
router = Router(
model_list=model_list,
default_litellm_params={"api_key": "test", "metadata": {"key": "value"}},
)
kwargs: dict = {"metadata": {"key2": "value2"}}
router._update_kwargs_with_default_litellm_params(kwargs=kwargs)
assert kwargs["api_key"] == "test"
assert kwargs["metadata"]["key"] == "value"
assert kwargs["metadata"]["key2"] == "value2"
def test_get_async_openai_model_client(model_list):
"""Test if the '_get_async_openai_model_client' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
model_client = router._get_async_openai_model_client(
deployment=deployment, kwargs={}
)
assert model_client is not None
def test_get_timeout(model_list):
"""Test if the '_get_timeout' function is working correctly"""
router = Router(model_list=model_list)
timeout = router._get_timeout(kwargs={}, data={"timeout": 100})
assert timeout == 100
@pytest.mark.parametrize(
"fallback_kwarg, expected_error",
[
("mock_testing_fallbacks", litellm.InternalServerError),
("mock_testing_context_fallbacks", litellm.ContextWindowExceededError),
("mock_testing_content_policy_fallbacks", litellm.ContentPolicyViolationError),
],
)
def test_handle_mock_testing_fallbacks(model_list, fallback_kwarg, expected_error):
"""Test if the '_handle_mock_testing_fallbacks' function is working correctly"""
router = Router(model_list=model_list)
with pytest.raises(expected_error):
data = {
fallback_kwarg: True,
}
router._handle_mock_testing_fallbacks(
kwargs=data,
)
def test_handle_mock_testing_rate_limit_error(model_list):
"""Test if the '_handle_mock_testing_rate_limit_error' function is working correctly"""
router = Router(model_list=model_list)
with pytest.raises(litellm.RateLimitError):
data = {
"mock_testing_rate_limit_error": True,
}
router._handle_mock_testing_rate_limit_error(
kwargs=data,
)
def test_get_fallback_model_group_from_fallbacks(model_list):
"""Test if the '_get_fallback_model_group_from_fallbacks' function is working correctly"""
router = Router(model_list=model_list)
fallback_model_group_name = router._get_fallback_model_group_from_fallbacks(
model_group="gpt-4o",
fallbacks=[{"gpt-4o": "gpt-3.5-turbo"}],
)
assert fallback_model_group_name == "gpt-3.5-turbo"
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_deployment_callback_on_success(model_list, sync_mode):
"""Test if the '_deployment_callback_on_success' function is working correctly"""
import time
router = Router(model_list=model_list)
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
},
"model_info": {"id": 100},
},
}
response = litellm.ModelResponse(
model="gpt-3.5-turbo",
usage={"total_tokens": 100},
)
if sync_mode:
tpm_key = router.sync_deployment_callback_on_success(
kwargs=kwargs,
completion_response=response,
start_time=time.time(),
end_time=time.time(),
)
else:
tpm_key = await router.deployment_callback_on_success(
kwargs=kwargs,
completion_response=response,
start_time=time.time(),
end_time=time.time(),
)
assert tpm_key is not None
def test_deployment_callback_on_failure(model_list):
"""Test if the '_deployment_callback_on_failure' function is working correctly"""
import time
router = Router(model_list=model_list)
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
},
"model_info": {"id": 100},
},
}
result = router.deployment_callback_on_failure(
kwargs=kwargs,
completion_response=None,
start_time=time.time(),
end_time=time.time(),
)
assert isinstance(result, bool)
assert result is False
def test_log_retry(model_list):
"""Test if the '_log_retry' function is working correctly"""
import time
router = Router(model_list=model_list)
new_kwargs = router.log_retry(
kwargs={"metadata": {}},
e=Exception(),
)
assert "metadata" in new_kwargs
assert "previous_models" in new_kwargs["metadata"]
def test_update_usage(model_list):
"""Test if the '_update_usage' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
deployment_id = deployment["model_info"]["id"]
request_count = router._update_usage(
deployment_id=deployment_id,
)
assert request_count == 1
request_count = router._update_usage(
deployment_id=deployment_id,
)
assert request_count == 2
@pytest.mark.parametrize(
"finish_reason, expected_error", [("content_filter", True), ("stop", False)]
)
def test_should_raise_content_policy_error(model_list, finish_reason, expected_error):
"""Test if the '_should_raise_content_policy_error' function is working correctly"""
router = Router(model_list=model_list)
assert (
router._should_raise_content_policy_error(
model="gpt-3.5-turbo",
response=litellm.ModelResponse(
model="gpt-3.5-turbo",
choices=[
{
"finish_reason": finish_reason,
"message": {"content": "I'm fine, thank you!"},
}
],
usage={"total_tokens": 100},
),
kwargs={
"content_policy_fallbacks": [{"gpt-3.5-turbo": "gpt-4o"}],
},
)
is expected_error
)
def test_get_healthy_deployments(model_list):
"""Test if the '_get_healthy_deployments' function is working correctly"""
router = Router(model_list=model_list)
deployments = router._get_healthy_deployments(model="gpt-3.5-turbo")
assert len(deployments) > 0
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_routing_strategy_pre_call_checks(model_list, sync_mode):
"""Test if the '_routing_strategy_pre_call_checks' function is working correctly"""
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging
callback = CustomLogger()
litellm.callbacks = [callback]
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
litellm_logging_obj = Logging(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
stream=False,
call_type="acompletion",
litellm_call_id="1234",
start_time=datetime.now(),
function_id="1234",
)
if sync_mode:
router.routing_strategy_pre_call_checks(deployment)
else:
## NO EXCEPTION
await router.async_routing_strategy_pre_call_checks(
deployment, litellm_logging_obj
)
## WITH EXCEPTION - rate limit error
with patch.object(
callback,
"async_pre_call_check",
AsyncMock(
side_effect=litellm.RateLimitError(
message="Rate limit error",
llm_provider="openai",
model="gpt-3.5-turbo",
)
),
):
try:
await router.async_routing_strategy_pre_call_checks(
deployment, litellm_logging_obj
)
pytest.fail("Exception was not raised")
except Exception as e:
assert isinstance(e, litellm.RateLimitError)
## WITH EXCEPTION - generic error
with patch.object(
callback, "async_pre_call_check", AsyncMock(side_effect=Exception("Error"))
):
try:
await router.async_routing_strategy_pre_call_checks(
deployment, litellm_logging_obj
)
pytest.fail("Exception was not raised")
except Exception as e:
assert isinstance(e, Exception)
@pytest.mark.parametrize(
"set_supported_environments, supported_environments, is_supported",
[(True, ["staging"], True), (False, None, True), (True, ["development"], False)],
)
def test_create_deployment(
model_list, set_supported_environments, supported_environments, is_supported
):
"""Test if the '_create_deployment' function is working correctly"""
router = Router(model_list=model_list)
if set_supported_environments:
os.environ["LITELLM_ENVIRONMENT"] = "staging"
deployment = router._create_deployment(
deployment_info={},
_model_name="gpt-3.5-turbo",
_litellm_params={
"model": "gpt-3.5-turbo",
"api_key": "test",
"custom_llm_provider": "openai",
},
_model_info={
"id": 100,
"supported_environments": supported_environments,
},
)
if is_supported:
assert deployment is not None
else:
assert deployment is None
@pytest.mark.parametrize(
"set_supported_environments, supported_environments, is_supported",
[(True, ["staging"], True), (False, None, True), (True, ["development"], False)],
)
def test_deployment_is_active_for_environment(
model_list, set_supported_environments, supported_environments, is_supported
):
"""Test if the '_deployment_is_active_for_environment' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
if set_supported_environments:
os.environ["LITELLM_ENVIRONMENT"] = "staging"
deployment["model_info"]["supported_environments"] = supported_environments
if is_supported:
assert (
router.deployment_is_active_for_environment(deployment=deployment) is True
)
else:
assert (
router.deployment_is_active_for_environment(deployment=deployment) is False
)
def test_set_model_list(model_list):
"""Test if the '_set_model_list' function is working correctly"""
router = Router(model_list=model_list)
router.set_model_list(model_list=model_list)
assert len(router.model_list) == len(model_list)
def test_add_deployment(model_list):
"""Test if the '_add_deployment' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
deployment["model_info"]["id"] = 100
## Test 1: call user facing function
router.add_deployment(deployment=deployment)
## Test 2: call internal function
router._add_deployment(deployment=deployment)
assert len(router.model_list) == len(model_list) + 1
def test_upsert_deployment(model_list):
"""Test if the 'upsert_deployment' function is working correctly"""
router = Router(model_list=model_list)
print("model list", len(router.model_list))
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
deployment.litellm_params.model = "gpt-4o"
router.upsert_deployment(deployment=deployment)
assert len(router.model_list) == len(model_list)
def test_delete_deployment(model_list):
"""Test if the 'delete_deployment' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
router.delete_deployment(id=deployment["model_info"]["id"])
assert len(router.model_list) == len(model_list) - 1
def test_get_model_info(model_list):
"""Test if the 'get_model_info' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
model_info = router.get_model_info(id=deployment["model_info"]["id"])
assert model_info is not None
def test_get_model_group(model_list):
"""Test if the 'get_model_group' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
model_group = router.get_model_group(id=deployment["model_info"]["id"])
assert model_group is not None
assert model_group[0]["model_name"] == "gpt-3.5-turbo"
@pytest.mark.parametrize("user_facing_model_group_name", ["gpt-3.5-turbo", "gpt-4o"])
def test_set_model_group_info(model_list, user_facing_model_group_name):
"""Test if the 'set_model_group_info' function is working correctly"""
router = Router(model_list=model_list)
resp = router._set_model_group_info(
model_group="gpt-3.5-turbo",
user_facing_model_group_name=user_facing_model_group_name,
)
assert resp is not None
assert resp.model_group == user_facing_model_group_name
@pytest.mark.asyncio
async def test_set_response_headers(model_list):
"""Test if the 'set_response_headers' function is working correctly"""
router = Router(model_list=model_list)
resp = await router.set_response_headers(response=None, model_group=None)
assert resp is None
def test_get_all_deployments(model_list):
"""Test if the 'get_all_deployments' function is working correctly"""
router = Router(model_list=model_list)
deployments = router._get_all_deployments(
model_name="gpt-3.5-turbo", model_alias="gpt-3.5-turbo"
)
assert len(deployments) > 0
def test_get_model_access_groups(model_list):
"""Test if the 'get_model_access_groups' function is working correctly"""
router = Router(model_list=model_list)
access_groups = router.get_model_access_groups()
assert len(access_groups) == 2
def test_update_settings(model_list):
"""Test if the 'update_settings' function is working correctly"""
router = Router(model_list=model_list)
pre_update_allowed_fails = router.allowed_fails
router.update_settings(**{"allowed_fails": 20})
assert router.allowed_fails != pre_update_allowed_fails
assert router.allowed_fails == 20
def test_common_checks_available_deployment(model_list):
"""Test if the 'common_checks_available_deployment' function is working correctly"""
router = Router(model_list=model_list)
_, available_deployments = router._common_checks_available_deployment(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
input="hi",
specific_deployment=False,
)
assert len(available_deployments) > 0
def test_filter_cooldown_deployments(model_list):
"""Test if the 'filter_cooldown_deployments' function is working correctly"""
router = Router(model_list=model_list)
deployments = router._filter_cooldown_deployments(
healthy_deployments=router._get_all_deployments(model_name="gpt-3.5-turbo"), # type: ignore
cooldown_deployments=[],
)
assert len(deployments) == len(
router._get_all_deployments(model_name="gpt-3.5-turbo")
)
def test_track_deployment_metrics(model_list):
"""Test if the 'track_deployment_metrics' function is working correctly"""
from litellm.types.utils import ModelResponse
router = Router(model_list=model_list)
router._track_deployment_metrics(
deployment=router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
),
response=ModelResponse(
model="gpt-3.5-turbo",
usage={"total_tokens": 100},
),
)
@pytest.mark.parametrize(
"exception_type, exception_name, num_retries",
[
(litellm.exceptions.BadRequestError, "BadRequestError", 3),
(litellm.exceptions.AuthenticationError, "AuthenticationError", 4),
(litellm.exceptions.RateLimitError, "RateLimitError", 6),
(
litellm.exceptions.ContentPolicyViolationError,
"ContentPolicyViolationError",
7,
),
],
)
def test_get_num_retries_from_retry_policy(
model_list, exception_type, exception_name, num_retries
):
"""Test if the 'get_num_retries_from_retry_policy' function is working correctly"""
from litellm.router import RetryPolicy
data = {exception_name + "Retries": num_retries}
print("data", data)
router = Router(
model_list=model_list,
retry_policy=RetryPolicy(**data),
)
print("exception_type", exception_type)
calc_num_retries = router.get_num_retries_from_retry_policy(
exception=exception_type(
message="test", llm_provider="openai", model="gpt-3.5-turbo"
)
)
assert calc_num_retries == num_retries
@pytest.mark.parametrize(
"exception_type, exception_name, allowed_fails",
[
(litellm.exceptions.BadRequestError, "BadRequestError", 3),
(litellm.exceptions.AuthenticationError, "AuthenticationError", 4),
(litellm.exceptions.RateLimitError, "RateLimitError", 6),
(
litellm.exceptions.ContentPolicyViolationError,
"ContentPolicyViolationError",
7,
),
],
)
def test_get_allowed_fails_from_policy(
model_list, exception_type, exception_name, allowed_fails
):
"""Test if the 'get_allowed_fails_from_policy' function is working correctly"""
from litellm.types.router import AllowedFailsPolicy
data = {exception_name + "AllowedFails": allowed_fails}
router = Router(
model_list=model_list, allowed_fails_policy=AllowedFailsPolicy(**data)
)
calc_allowed_fails = router.get_allowed_fails_from_policy(
exception=exception_type(
message="test", llm_provider="openai", model="gpt-3.5-turbo"
)
)
assert calc_allowed_fails == allowed_fails
def test_initialize_alerting(model_list):
"""Test if the 'initialize_alerting' function is working correctly"""
from litellm.types.router import AlertingConfig
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
router = Router(
model_list=model_list, alerting_config=AlertingConfig(webhook_url="test")
)
router._initialize_alerting()
callback_added = False
for callback in litellm.callbacks:
if isinstance(callback, SlackAlerting):
callback_added = True
assert callback_added is True
def test_flush_cache(model_list):
"""Test if the 'flush_cache' function is working correctly"""
router = Router(model_list=model_list)
router.cache.set_cache("test", "test")
assert router.cache.get_cache("test") == "test"
router.flush_cache()
assert router.cache.get_cache("test") is None