forked from phoenix/litellm-mirror
(testing) Router add testing coverage (#6253)
* test: add more router code coverage
* test: additional router testing coverage
* fix: fix linting error
* test: fix tests for ci/cd
* test: fix test
* test: handle flaky tests

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
Parent: 54ebdbf7ce
Commit: 8530000b44
7 changed files with 706 additions and 106 deletions
@@ -861,11 +861,23 @@ class Router:
                self.fail_calls[model_name] += 1
            raise e

    def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
        """
        Adds default litellm params to kwargs, if set.
        """
        for k, v in self.default_litellm_params.items():
            if (
                k not in kwargs and v is not None
            ):  # prioritize model-specific params > default router params
                kwargs[k] = v
            elif k == "metadata":
                kwargs[k].update(v)

    def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None:
        """
        Adds selected deployment, model_info and api_base to kwargs["metadata"]

        This is used in litellm logging callbacks

        2 jobs:
        - Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)
        - Adds default litellm params to kwargs, if set.
        """
        kwargs.setdefault("metadata", {}).update(
            {

@@ -875,13 +887,7 @@ class Router:
            }
        )
        kwargs["model_info"] = deployment.get("model_info", {})
        for k, v in self.default_litellm_params.items():
            if (
                k not in kwargs and v is not None
            ):  # prioritize model-specific params > default router params
                kwargs[k] = v
            elif k == "metadata":
                kwargs[k].update(v)
        self._update_kwargs_with_default_litellm_params(kwargs=kwargs)

    def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
        """
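A minimal standalone sketch of the merge rule described in the new `_update_kwargs_with_default_litellm_params` docstring: model-specific kwargs win over router-level defaults, while `metadata` dictionaries are merged. The function and sample values below are illustrative, not the Router method itself.

# Standalone sketch, not the litellm implementation: assumes default_litellm_params
# and kwargs are plain dicts, mirroring the helper above.
def merge_default_params(kwargs: dict, default_litellm_params: dict) -> None:
    for k, v in default_litellm_params.items():
        if k not in kwargs and v is not None:
            # model-specific params take priority over router-level defaults
            kwargs[k] = v
        elif k == "metadata":
            # metadata is merged rather than replaced
            kwargs[k].update(v)

kwargs = {"api_key": "model-specific-key", "metadata": {"caller": "test"}}
merge_default_params(
    kwargs, {"api_key": "router-default", "metadata": {"team": "a"}, "timeout": 30}
)
assert kwargs["api_key"] == "model-specific-key"  # not overridden
assert kwargs["timeout"] == 30  # default filled in
assert kwargs["metadata"] == {"caller": "test", "team": "a"}  # merged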
@@ -910,6 +916,7 @@ class Router:
        return model_client

    def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
        """Helper to get timeout from kwargs or deployment params"""
        timeout = (
            data.get(
                "timeout", None
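The `_get_timeout` hunk is cut off above, but its docstring states the intent: resolve a timeout from either the deployment's params (`data`) or the call kwargs. A hedged sketch of that kind of precedence chain; the exact order and field names in litellm may differ.

from typing import Optional, Union

def resolve_timeout(kwargs: dict, data: dict) -> Optional[Union[float, int]]:
    # deployment-level params first, then the per-call kwargs (illustrative order)
    return (
        data.get("timeout")
        or data.get("request_timeout")
        or kwargs.get("timeout")
    )

assert resolve_timeout(kwargs={}, data={"timeout": 100}) == 100
assert resolve_timeout(kwargs={"timeout": 5}, data={}) == 5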
@@ -3414,11 +3421,10 @@ class Router:
    ):
        """
        Track remaining tpm/rpm quota for model in model_list

        Currently, only updates TPM usage.
        """
        try:
            """
            Update TPM usage on success
            """
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:

@@ -3459,6 +3465,8 @@ class Router:
                    deployment_id=id,
                )

                return tpm_key
        except Exception as e:
            verbose_router_logger.exception(
                "litellm.router.Router::deployment_callback_on_success(): Exception occured - {}".format(

@@ -3473,7 +3481,14 @@ class Router:
        completion_response, # response from completion
        start_time,
        end_time, # start/end time
    ):
    ) -> Optional[str]:
        """
        Tracks the number of successes for a deployment in the current minute (using in-memory cache)

        Returns:
        - key: str - The key used to increment the cache
        - None: if no key is found
        """
        id = None
        if kwargs["litellm_params"].get("metadata") is None:
            pass

@@ -3482,15 +3497,18 @@ class Router:
            model_info = kwargs["litellm_params"].get("model_info", {}) or {}
            id = model_info.get("id", None)
            if model_group is None or id is None:
                return
                return None
            elif isinstance(id, int):
                id = str(id)

        if id is not None:
            increment_deployment_successes_for_current_minute(
            key = increment_deployment_successes_for_current_minute(
                litellm_router_instance=self,
                deployment_id=id,
            )
            return key

        return None

    def deployment_callback_on_failure(
        self,
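`deployment_callback_on_success` and its sync variant now return the cache key they incremented, which is what the new unit tests assert on. A hedged, standalone sketch of the minute-bucketed counter idea; the key format and storage here are illustrative, not litellm's actual helper.

# Standalone sketch of a minute-bucketed success counter; a plain dict stands in
# for litellm's in-memory cache and the key format is made up for illustration.
import time
from collections import defaultdict

_counters: dict = defaultdict(int)

def increment_successes_for_current_minute(deployment_id: str) -> str:
    current_minute = int(time.time() // 60)
    key = f"{deployment_id}:successes:{current_minute}"
    _counters[key] += 1
    return key  # returning the key lets tests assert the right bucket was touched

key = increment_successes_for_current_minute("deployment-1")
assert _counters[key] == 1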
@@ -3498,15 +3516,19 @@ class Router:
        completion_response, # response from completion
        start_time,
        end_time, # start/end time
    ):
    ) -> bool:
        """
        2 jobs:
        - Tracks the number of failures for a deployment in the current minute (using in-memory cache)
        - Puts the deployment in cooldown if it exceeds the allowed fails / minute

        Returns:
        - True if the deployment should be put in cooldown
        - False if the deployment should not be put in cooldown
        """
        try:
            exception = kwargs.get("exception", None)
            exception_status = getattr(exception, "status_code", "")
            model_name = kwargs.get("model", None) # i.e. gpt35turbo
            custom_llm_provider = kwargs.get("litellm_params", {}).get(
                "custom_llm_provider", None
            ) # i.e. azure
            kwargs.get("litellm_params", {}).get("metadata", None)
            _model_info = kwargs.get("litellm_params", {}).get("model_info", {})

            exception_headers = litellm.litellm_core_utils.exception_mapping_utils._get_response_headers(

@@ -3535,15 +3557,17 @@ class Router:
                    litellm_router_instance=self,
                    deployment_id=deployment_id,
                )
                _set_cooldown_deployments(
                result = _set_cooldown_deployments(
                    litellm_router_instance=self,
                    exception_status=exception_status,
                    original_exception=exception,
                    deployment=deployment_id,
                    time_to_cooldown=_time_to_cooldown,
                ) # setting deployment_id in cooldown deployments
                if custom_llm_provider:
                    model_name = f"{custom_llm_provider}/{model_name}"

                return result
            else:
                return False

        except Exception as e:
            raise e
@@ -3582,9 +3606,12 @@ class Router:
        except Exception as e:
            raise e

    def _update_usage(self, deployment_id: str):
    def _update_usage(self, deployment_id: str) -> int:
        """
        Update deployment rpm for that minute

        Returns:
        - int: request count
        """
        rpm_key = deployment_id

@@ -3600,6 +3627,8 @@ class Router:
            key=rpm_key, value=request_count, local_only=True
        ) # don't change existing ttl

        return request_count

    def _is_cooldown_required(
        self,
        model_id: str,
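`_update_usage` now returns the per-minute request count, and the new unit test expects 1 on the first call and 2 on the second. A standalone sketch of that counting contract, with a plain dict standing in for litellm's in-memory cache.

# Standalone sketch of per-deployment request counting, not the Router method itself.
_request_counts: dict = {}

def update_usage(deployment_id: str) -> int:
    request_count = _request_counts.get(deployment_id, 0) + 1
    _request_counts[deployment_id] = request_count
    return request_count

assert update_usage("deployment-1") == 1
assert update_usage("deployment-1") == 2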
@@ -3778,7 +3807,7 @@ class Router:
        for _callback in litellm.callbacks:
            if isinstance(_callback, CustomLogger):
                try:
                    _ = await _callback.async_pre_call_check(deployment)
                    await _callback.async_pre_call_check(deployment)
                except litellm.RateLimitError as e:
                    ## LOG FAILURE EVENT
                    if logging_obj is not None:
@@ -3848,10 +3877,23 @@ class Router:
        return hash_object.hexdigest()

    def _create_deployment(
        self, model: dict, _model_name: str, _litellm_params: dict, _model_info: dict
    ):
        self,
        deployment_info: dict,
        _model_name: str,
        _litellm_params: dict,
        _model_info: dict,
    ) -> Optional[Deployment]:
        """
        Create a deployment object and add it to the model list

        If the deployment is not active for the current environment, it is ignored

        Returns:
        - Deployment: The deployment object
        - None: If the deployment is not active for the current environment (if 'supported_environments' is set in litellm_params)
        """
        deployment = Deployment(
            **model,
            **deployment_info,
            model_name=_model_name,
            litellm_params=LiteLLM_Params(**_litellm_params),
            model_info=_model_info,

@@ -3870,18 +3912,18 @@ class Router:
        )

        ## Check if LLM Deployment is allowed for this deployment
        if deployment.model_info and "supported_environments" in deployment.model_info:
            if (
                self.deployment_is_active_for_environment(deployment=deployment)
                is not True
            ):
                return
        if self.deployment_is_active_for_environment(deployment=deployment) is not True:
            verbose_router_logger.warning(
                f"Ignoring deployment {deployment.model_name} as it is not active for environment {deployment.model_info['supported_environments']}"
            )
            return None

        deployment = self._add_deployment(deployment=deployment)

        model = deployment.to_json(exclude_none=True)

        self.model_list.append(model)
        return deployment

    def deployment_is_active_for_environment(self, deployment: Deployment) -> bool:
        """

@@ -3896,6 +3938,12 @@ class Router:
        - ValueError: If LITELLM_ENVIRONMENT is not set in .env or not one of the valid values
        - ValueError: If supported_environments is not set in model_info or not one of the valid values
        """
        if (
            deployment.model_info is None
            or "supported_environments" not in deployment.model_info
            or deployment.model_info["supported_environments"] is None
        ):
            return True
        litellm_environment = get_secret_str(secret_name="LITELLM_ENVIRONMENT")
        if litellm_environment is None:
            raise ValueError(

@@ -3913,7 +3961,6 @@ class Router:
                f"supported_environments must be one of {VALID_LITELLM_ENVIRONMENTS}. but set as: {_env} for deployment: {deployment}"
            )

        # validate litellm_environment is one of LiteLLMEnvironment
        if litellm_environment in deployment.model_info["supported_environments"]:
            return True
        return False
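A hedged sketch of the environment gate described above: a deployment stays in the model list only when `LITELLM_ENVIRONMENT` matches one of its `supported_environments` (or when no restriction is configured). Standalone function, not the Router method.

import os
from typing import Optional

def is_active_for_environment(supported_environments: Optional[list]) -> bool:
    if not supported_environments:
        # no restriction configured -> active everywhere
        return True
    litellm_environment = os.getenv("LITELLM_ENVIRONMENT")
    if litellm_environment is None:
        raise ValueError(
            "LITELLM_ENVIRONMENT must be set when supported_environments is used"
        )
    return litellm_environment in supported_environments

os.environ["LITELLM_ENVIRONMENT"] = "staging"
assert is_active_for_environment(["staging"]) is True
assert is_active_for_environment(["development"]) is False
assert is_active_for_environment(None) is True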
@@ -3946,14 +3993,14 @@ class Router:
            for org in _litellm_params["organization"]:
                _litellm_params["organization"] = org
                self._create_deployment(
                    model=model,
                    deployment_info=model,
                    _model_name=_model_name,
                    _litellm_params=_litellm_params,
                    _model_info=_model_info,
                )
        else:
            self._create_deployment(
                model=model,
                deployment_info=model,
                _model_name=_model_name,
                _litellm_params=_litellm_params,
                _model_info=_model_info,

@@ -4118,7 +4165,7 @@ class Router:

        if removal_idx is not None:
            self.model_list.pop(removal_idx)
        else:
            # if the model_id is not in router
            self.add_deployment(deployment=deployment)
            return deployment
@@ -4628,16 +4675,13 @@ class Router:
        from collections import defaultdict

        access_groups = defaultdict(list)
        if self.access_groups:
            return self.access_groups

        if self.model_list:
            for m in self.model_list:
                for group in m.get("model_info", {}).get("access_groups", []):
                    model_name = m["model_name"]
                    access_groups[group].append(model_name)
        # set access groups
        self.access_groups = access_groups

        return access_groups

    def get_settings(self):
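A standalone sketch of the access-group mapping built above: group name mapped to the model names whose `model_info` lists that group. Plain dicts and `defaultdict`, not the Router class.

from collections import defaultdict

def get_model_access_groups(model_list: list) -> dict:
    access_groups = defaultdict(list)
    for m in model_list:
        for group in m.get("model_info", {}).get("access_groups", []):
            access_groups[group].append(m["model_name"])
    return access_groups

groups = get_model_access_groups(
    [{"model_name": "gpt-3.5-turbo", "model_info": {"access_groups": ["group1", "group2"]}}]
)
assert groups == {"group1": ["gpt-3.5-turbo"], "group2": ["gpt-3.5-turbo"]}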
@@ -4672,6 +4716,9 @@ class Router:
        return _settings_to_return

    def update_settings(self, **kwargs):
        """
        Update the router settings.
        """
        # only the following settings are allowed to be configured
        _allowed_settings = [
            "routing_strategy_args",
@@ -5367,66 +5414,16 @@ class Router:
        return healthy_deployments

    def _track_deployment_metrics(self, deployment, response=None):
        """
        Tracks successful requests rpm usage.
        """
        try:
            litellm_params = deployment["litellm_params"]
            api_base = litellm_params.get("api_base", "")
            model = litellm_params.get("model", "")

            model_id = deployment.get("model_info", {}).get("id", None)
            if response is None:

                # update self.deployment_stats
                if model_id is not None:
                    self._update_usage(model_id) # update in-memory cache for tracking
                    if model_id in self.deployment_stats:
                        # only update num_requests
                        self.deployment_stats[model_id]["num_requests"] += 1
                    else:
                        self.deployment_stats[model_id] = {
                            "api_base": api_base,
                            "model": model,
                            "num_requests": 1,
                        }
            else:
                # check response_ms and update num_successes
                if isinstance(response, dict):
                    response_ms = response.get("_response_ms", 0)
                else:
                    response_ms = 0
                if model_id is not None:
                    if model_id in self.deployment_stats:
                        # check if avg_latency exists
                        if "avg_latency" in self.deployment_stats[model_id]:
                            # update avg_latency
                            self.deployment_stats[model_id]["avg_latency"] = (
                                self.deployment_stats[model_id]["avg_latency"]
                                + response_ms
                            ) / self.deployment_stats[model_id]["num_successes"]
                        else:
                            self.deployment_stats[model_id]["avg_latency"] = response_ms

                        # check if num_successes exists
                        if "num_successes" in self.deployment_stats[model_id]:
                            self.deployment_stats[model_id]["num_successes"] += 1
                        else:
                            self.deployment_stats[model_id]["num_successes"] = 1
                    else:
                        self.deployment_stats[model_id] = {
                            "api_base": api_base,
                            "model": model,
                            "num_successes": 1,
                            "avg_latency": response_ms,
                        }
            if self.set_verbose is True and self.debug_level == "DEBUG":
                from pprint import pformat

                # Assuming self.deployment_stats is your dictionary
                formatted_stats = pformat(self.deployment_stats)

                # Assuming verbose_router_logger is your logger
                verbose_router_logger.info(
                    "self.deployment_stats: \n%s", formatted_stats
                )
        except Exception as e:
            verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
@@ -5442,6 +5439,7 @@ class Router:
        """
        # if we can find the exception then in the retry policy -> return the number of retries
        retry_policy: Optional[RetryPolicy] = self.retry_policy

        if (
            self.model_group_retry_policy is not None
            and model_group is not None

@@ -5540,7 +5538,9 @@ class Router:
            litellm.success_callback.append(
                _slack_alerting_logger.response_taking_too_long_callback
            )
            print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n") # noqa
            verbose_router_logger.info(
                "\033[94m\nInitialized Alerting for litellm.Router\033[0m\n"
            )

    def set_custom_routing_strategy(
        self, CustomRoutingStrategy: CustomRoutingStrategyBase
@@ -148,13 +148,17 @@ def _set_cooldown_deployments(
    exception_status: Union[str, int],
    deployment: Optional[str] = None,
    time_to_cooldown: Optional[float] = None,
):
) -> bool:
    """
    Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute

    or

    the exception is not one that should be immediately retried (e.g. 401)

    Returns:
    - True if the deployment should be put in cooldown
    - False if the deployment should not be put in cooldown
    """
    if (
        _should_run_cooldown_logic(

@@ -163,7 +167,7 @@ def _set_cooldown_deployments(
        is False
        or deployment is None
    ):
        return
        return False

    exception_status_int = cast_exception_status_to_int(exception_status)

@@ -191,6 +195,8 @@ def _set_cooldown_deployments(
                cooldown_time=cooldown_time,
            )
        )
        return True
    return False


async def _async_get_cooldown_deployments(
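`_set_cooldown_deployments` now reports whether it actually put the deployment in cooldown, and `deployment_callback_on_failure` propagates that boolean. A hedged, simplified sketch of the contract; the real guard also considers the exception status and cooldown settings.

from typing import Optional, Set

def set_cooldown(
    deployment: Optional[str],
    fails_this_minute: int,
    allowed_fails: int,
    cooldown_deployments: Set[str],
) -> bool:
    # mirror the early exit: nothing to cool down without a deployment id
    if deployment is None:
        return False
    # cool down only once the allowed fails per minute are exceeded
    if fails_this_minute <= allowed_fails:
        return False
    cooldown_deployments.add(deployment)
    return True

cooldowns: Set[str] = set()
assert set_cooldown("deployment-1", fails_this_minute=5, allowed_fails=3, cooldown_deployments=cooldowns) is True
assert set_cooldown(None, fails_this_minute=5, allowed_fails=3, cooldown_deployments=cooldowns) is False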
@@ -24,7 +24,7 @@ else:
def increment_deployment_successes_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
):
) -> str:
    """
    In-Memory: Increments the number of successes for the current minute for a deployment_id
    """

@@ -35,6 +35,7 @@ def increment_deployment_successes_for_current_minute(
        value=1,
        ttl=60,
    )
    return key


def increment_deployment_failures_for_current_minute(
@@ -11,9 +11,15 @@ def get_function_names_from_file(file_path):

    function_names = []

    for node in ast.walk(tree):
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Top-level functions
            function_names.append(node.name)
        elif isinstance(node, ast.ClassDef):
            # Functions inside classes
            for class_node in node.body:
                if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    function_names.append(class_node.name)

    return function_names


@@ -79,6 +85,7 @@ ignored_function_names = [
    "a_add_message",
    "aget_messages",
    "arun_thread",
    "try_retrieve_batch",
]


@@ -103,8 +110,8 @@ def main():
            if func not in ignored_function_names:
                all_untested_functions.append(func)
    untested_perc = (len(all_untested_functions)) / len(router_functions)
    print("perc_covered: ", untested_perc)
    if untested_perc < 0.3:
    print("untested_perc: ", untested_perc)
    if untested_perc > 0:
        print("The following functions in router.py are not tested:")
        raise Exception(
            f"{untested_perc * 100:.2f}% of functions in router.py are not tested: {all_untested_functions}"
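The coverage script now iterates `tree.body` (plus class bodies) instead of `ast.walk(tree)`. A quick standalone illustration of the difference on a made-up source snippet: the scoped walk keeps top-level functions and methods but skips helpers nested inside other functions.

import ast

source = '''
def top_level():
    def nested_helper():
        pass

class Router:
    def method(self):
        pass
'''

tree = ast.parse(source)

# ast.walk visits every node, including functions nested inside other functions
walked = [n.name for n in ast.walk(tree) if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))]
assert sorted(walked) == ["method", "nested_helper", "top_level"]

# iterating tree.body plus class bodies keeps only top-level functions and methods
scoped = []
for node in tree.body:
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        scoped.append(node.name)
    elif isinstance(node, ast.ClassDef):
        scoped.extend(
            n.name for n in node.body if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
        )
assert sorted(scoped) == ["method", "top_level"]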
@@ -20,6 +20,7 @@ import boto3


@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.flaky(retries=6, delay=1)
async def test_basic_s3_logging(sync_mode):
    verbose_logger.setLevel(level=logging.DEBUG)
    litellm.success_callback = ["s3"]


@@ -3789,6 +3789,7 @@ def test_completion_anyscale_api():


# @pytest.mark.skip(reason="flaky test, times out frequently")
@pytest.mark.flaky(retries=6, delay=1)
def test_completion_cohere():
    try:
        # litellm.set_verbose=True
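The `@pytest.mark.flaky(retries=6, delay=1)` marker added here re-runs a failing test before reporting it. The `retries=`/`delay=` keywords look like the pytest-retry plugin's marker; that attribution is an assumption, since the diff does not name the plugin. A minimal usage sketch:

# Assumes the pytest-retry plugin is installed (pip install pytest-retry); with it,
# this test is re-run up to 3 times, pausing 1 second between attempts.
import random

import pytest


@pytest.mark.flaky(retries=3, delay=1)
def test_sometimes_fails():
    assert random.random() < 0.9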
@@ -10,6 +10,7 @@ sys.path.insert(
) # Adds the parent directory to the system path
from litellm import Router
import pytest
import litellm
from unittest.mock import patch, MagicMock, AsyncMock


@@ -22,6 +23,9 @@ def model_list():
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "model_info": {
                "access_groups": ["group1", "group2"],
            },
        },
        {
            "model_name": "gpt-4o",
@@ -250,3 +254,583 @@ async def test_router_make_call(model_list):
        mock_response="https://example.com/image.png",
    )
    assert response.data[0].url == "https://example.com/image.png"


def test_update_kwargs_with_deployment(model_list):
    """Test if the '_update_kwargs_with_deployment' function is working correctly"""
    router = Router(model_list=model_list)
    kwargs: dict = {"metadata": {}}
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    router._update_kwargs_with_deployment(
        deployment=deployment,
        kwargs=kwargs,
    )
    set_fields = ["deployment", "api_base", "model_info"]
    assert all(field in kwargs["metadata"] for field in set_fields)


def test_update_kwargs_with_default_litellm_params(model_list):
    """Test if the '_update_kwargs_with_default_litellm_params' function is working correctly"""
    router = Router(
        model_list=model_list,
        default_litellm_params={"api_key": "test", "metadata": {"key": "value"}},
    )
    kwargs: dict = {"metadata": {"key2": "value2"}}
    router._update_kwargs_with_default_litellm_params(kwargs=kwargs)
    assert kwargs["api_key"] == "test"
    assert kwargs["metadata"]["key"] == "value"
    assert kwargs["metadata"]["key2"] == "value2"


def test_get_async_openai_model_client(model_list):
    """Test if the '_get_async_openai_model_client' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    model_client = router._get_async_openai_model_client(
        deployment=deployment, kwargs={}
    )
    assert model_client is not None


def test_get_timeout(model_list):
    """Test if the '_get_timeout' function is working correctly"""
    router = Router(model_list=model_list)
    timeout = router._get_timeout(kwargs={}, data={"timeout": 100})
    assert timeout == 100


@pytest.mark.parametrize(
    "fallback_kwarg, expected_error",
    [
        ("mock_testing_fallbacks", litellm.InternalServerError),
        ("mock_testing_context_fallbacks", litellm.ContextWindowExceededError),
        ("mock_testing_content_policy_fallbacks", litellm.ContentPolicyViolationError),
    ],
)
def test_handle_mock_testing_fallbacks(model_list, fallback_kwarg, expected_error):
    """Test if the '_handle_mock_testing_fallbacks' function is working correctly"""
    router = Router(model_list=model_list)
    with pytest.raises(expected_error):
        data = {
            fallback_kwarg: True,
        }
        router._handle_mock_testing_fallbacks(
            kwargs=data,
        )


def test_handle_mock_testing_rate_limit_error(model_list):
    """Test if the '_handle_mock_testing_rate_limit_error' function is working correctly"""
    router = Router(model_list=model_list)
    with pytest.raises(litellm.RateLimitError):
        data = {
            "mock_testing_rate_limit_error": True,
        }
        router._handle_mock_testing_rate_limit_error(
            kwargs=data,
        )


def test_get_fallback_model_group_from_fallbacks(model_list):
    """Test if the '_get_fallback_model_group_from_fallbacks' function is working correctly"""
    router = Router(model_list=model_list)
    fallback_model_group_name = router._get_fallback_model_group_from_fallbacks(
        model_group="gpt-4o",
        fallbacks=[{"gpt-4o": "gpt-3.5-turbo"}],
    )
    assert fallback_model_group_name == "gpt-3.5-turbo"


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_deployment_callback_on_success(model_list, sync_mode):
    """Test if the '_deployment_callback_on_success' function is working correctly"""
    import time

    router = Router(model_list=model_list)

    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": "gpt-3.5-turbo",
            },
            "model_info": {"id": 100},
        },
    }
    response = litellm.ModelResponse(
        model="gpt-3.5-turbo",
        usage={"total_tokens": 100},
    )
    if sync_mode:
        tpm_key = router.sync_deployment_callback_on_success(
            kwargs=kwargs,
            completion_response=response,
            start_time=time.time(),
            end_time=time.time(),
        )
    else:
        tpm_key = await router.deployment_callback_on_success(
            kwargs=kwargs,
            completion_response=response,
            start_time=time.time(),
            end_time=time.time(),
        )
    assert tpm_key is not None


def test_deployment_callback_on_failure(model_list):
    """Test if the '_deployment_callback_on_failure' function is working correctly"""
    import time

    router = Router(model_list=model_list)
    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": "gpt-3.5-turbo",
            },
            "model_info": {"id": 100},
        },
    }
    result = router.deployment_callback_on_failure(
        kwargs=kwargs,
        completion_response=None,
        start_time=time.time(),
        end_time=time.time(),
    )
    assert isinstance(result, bool)
    assert result is False


def test_log_retry(model_list):
    """Test if the '_log_retry' function is working correctly"""
    import time

    router = Router(model_list=model_list)
    new_kwargs = router.log_retry(
        kwargs={"metadata": {}},
        e=Exception(),
    )
    assert "metadata" in new_kwargs
    assert "previous_models" in new_kwargs["metadata"]


def test_update_usage(model_list):
    """Test if the '_update_usage' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    deployment_id = deployment["model_info"]["id"]
    request_count = router._update_usage(
        deployment_id=deployment_id,
    )
    assert request_count == 1

    request_count = router._update_usage(
        deployment_id=deployment_id,
    )

    assert request_count == 2


@pytest.mark.parametrize(
    "finish_reason, expected_error", [("content_filter", True), ("stop", False)]
)
def test_should_raise_content_policy_error(model_list, finish_reason, expected_error):
    """Test if the '_should_raise_content_policy_error' function is working correctly"""
    router = Router(model_list=model_list)

    assert (
        router._should_raise_content_policy_error(
            model="gpt-3.5-turbo",
            response=litellm.ModelResponse(
                model="gpt-3.5-turbo",
                choices=[
                    {
                        "finish_reason": finish_reason,
                        "message": {"content": "I'm fine, thank you!"},
                    }
                ],
                usage={"total_tokens": 100},
            ),
            kwargs={
                "content_policy_fallbacks": [{"gpt-3.5-turbo": "gpt-4o"}],
            },
        )
        is expected_error
    )


def test_get_healthy_deployments(model_list):
    """Test if the '_get_healthy_deployments' function is working correctly"""
    router = Router(model_list=model_list)
    deployments = router._get_healthy_deployments(model="gpt-3.5-turbo")
    assert len(deployments) > 0


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_routing_strategy_pre_call_checks(model_list, sync_mode):
    """Test if the '_routing_strategy_pre_call_checks' function is working correctly"""
    from litellm.integrations.custom_logger import CustomLogger
    from litellm.litellm_core_utils.litellm_logging import Logging

    callback = CustomLogger()
    litellm.callbacks = [callback]

    router = Router(model_list=model_list)

    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )

    litellm_logging_obj = Logging(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        call_type="acompletion",
        litellm_call_id="1234",
        start_time=datetime.now(),
        function_id="1234",
    )
    if sync_mode:
        router.routing_strategy_pre_call_checks(deployment)
    else:
        ## NO EXCEPTION
        await router.async_routing_strategy_pre_call_checks(
            deployment, litellm_logging_obj
        )

        ## WITH EXCEPTION - rate limit error
        with patch.object(
            callback,
            "async_pre_call_check",
            AsyncMock(
                side_effect=litellm.RateLimitError(
                    message="Rate limit error",
                    llm_provider="openai",
                    model="gpt-3.5-turbo",
                )
            ),
        ):
            try:
                await router.async_routing_strategy_pre_call_checks(
                    deployment, litellm_logging_obj
                )
                pytest.fail("Exception was not raised")
            except Exception as e:
                assert isinstance(e, litellm.RateLimitError)

        ## WITH EXCEPTION - generic error
        with patch.object(
            callback, "async_pre_call_check", AsyncMock(side_effect=Exception("Error"))
        ):
            try:
                await router.async_routing_strategy_pre_call_checks(
                    deployment, litellm_logging_obj
                )
                pytest.fail("Exception was not raised")
            except Exception as e:
                assert isinstance(e, Exception)


@pytest.mark.parametrize(
    "set_supported_environments, supported_environments, is_supported",
    [(True, ["staging"], True), (False, None, True), (True, ["development"], False)],
)
def test_create_deployment(
    model_list, set_supported_environments, supported_environments, is_supported
):
    """Test if the '_create_deployment' function is working correctly"""
    router = Router(model_list=model_list)

    if set_supported_environments:
        os.environ["LITELLM_ENVIRONMENT"] = "staging"
    deployment = router._create_deployment(
        deployment_info={},
        _model_name="gpt-3.5-turbo",
        _litellm_params={
            "model": "gpt-3.5-turbo",
            "api_key": "test",
            "custom_llm_provider": "openai",
        },
        _model_info={
            "id": 100,
            "supported_environments": supported_environments,
        },
    )
    if is_supported:
        assert deployment is not None
    else:
        assert deployment is None


@pytest.mark.parametrize(
    "set_supported_environments, supported_environments, is_supported",
    [(True, ["staging"], True), (False, None, True), (True, ["development"], False)],
)
def test_deployment_is_active_for_environment(
    model_list, set_supported_environments, supported_environments, is_supported
):
    """Test if the '_deployment_is_active_for_environment' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    if set_supported_environments:
        os.environ["LITELLM_ENVIRONMENT"] = "staging"
    deployment["model_info"]["supported_environments"] = supported_environments
    if is_supported:
        assert (
            router.deployment_is_active_for_environment(deployment=deployment) is True
        )
    else:
        assert (
            router.deployment_is_active_for_environment(deployment=deployment) is False
        )


def test_set_model_list(model_list):
    """Test if the '_set_model_list' function is working correctly"""
    router = Router(model_list=model_list)
    router.set_model_list(model_list=model_list)
    assert len(router.model_list) == len(model_list)


def test_add_deployment(model_list):
    """Test if the '_add_deployment' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    deployment["model_info"]["id"] = 100
    ## Test 1: call user facing function
    router.add_deployment(deployment=deployment)

    ## Test 2: call internal function
    router._add_deployment(deployment=deployment)
    assert len(router.model_list) == len(model_list) + 1


def test_upsert_deployment(model_list):
    """Test if the 'upsert_deployment' function is working correctly"""
    router = Router(model_list=model_list)
    print("model list", len(router.model_list))
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    deployment.litellm_params.model = "gpt-4o"
    router.upsert_deployment(deployment=deployment)
    assert len(router.model_list) == len(model_list)


def test_delete_deployment(model_list):
    """Test if the 'delete_deployment' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    router.delete_deployment(id=deployment["model_info"]["id"])
    assert len(router.model_list) == len(model_list) - 1


def test_get_model_info(model_list):
    """Test if the 'get_model_info' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    model_info = router.get_model_info(id=deployment["model_info"]["id"])
    assert model_info is not None


def test_get_model_group(model_list):
    """Test if the 'get_model_group' function is working correctly"""
    router = Router(model_list=model_list)
    deployment = router.get_deployment_by_model_group_name(
        model_group_name="gpt-3.5-turbo"
    )
    model_group = router.get_model_group(id=deployment["model_info"]["id"])
    assert model_group is not None
    assert model_group[0]["model_name"] == "gpt-3.5-turbo"


@pytest.mark.parametrize("user_facing_model_group_name", ["gpt-3.5-turbo", "gpt-4o"])
def test_set_model_group_info(model_list, user_facing_model_group_name):
    """Test if the 'set_model_group_info' function is working correctly"""
    router = Router(model_list=model_list)
    resp = router._set_model_group_info(
        model_group="gpt-3.5-turbo",
        user_facing_model_group_name=user_facing_model_group_name,
    )
    assert resp is not None
    assert resp.model_group == user_facing_model_group_name


@pytest.mark.asyncio
async def test_set_response_headers(model_list):
    """Test if the 'set_response_headers' function is working correctly"""
    router = Router(model_list=model_list)
    resp = await router.set_response_headers(response=None, model_group=None)
    assert resp is None


def test_get_all_deployments(model_list):
    """Test if the 'get_all_deployments' function is working correctly"""
    router = Router(model_list=model_list)
    deployments = router._get_all_deployments(
        model_name="gpt-3.5-turbo", model_alias="gpt-3.5-turbo"
    )
    assert len(deployments) > 0


def test_get_model_access_groups(model_list):
    """Test if the 'get_model_access_groups' function is working correctly"""
    router = Router(model_list=model_list)
    access_groups = router.get_model_access_groups()
    assert len(access_groups) == 2


def test_update_settings(model_list):
    """Test if the 'update_settings' function is working correctly"""
    router = Router(model_list=model_list)
    pre_update_allowed_fails = router.allowed_fails
    router.update_settings(**{"allowed_fails": 20})
    assert router.allowed_fails != pre_update_allowed_fails
    assert router.allowed_fails == 20


def test_common_checks_available_deployment(model_list):
    """Test if the 'common_checks_available_deployment' function is working correctly"""
    router = Router(model_list=model_list)
    _, available_deployments = router._common_checks_available_deployment(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        input="hi",
        specific_deployment=False,
    )

    assert len(available_deployments) > 0


def test_filter_cooldown_deployments(model_list):
    """Test if the 'filter_cooldown_deployments' function is working correctly"""
    router = Router(model_list=model_list)
    deployments = router._filter_cooldown_deployments(
        healthy_deployments=router._get_all_deployments(model_name="gpt-3.5-turbo"), # type: ignore
        cooldown_deployments=[],
    )
    assert len(deployments) == len(
        router._get_all_deployments(model_name="gpt-3.5-turbo")
    )


def test_track_deployment_metrics(model_list):
    """Test if the 'track_deployment_metrics' function is working correctly"""
    from litellm.types.utils import ModelResponse

    router = Router(model_list=model_list)
    router._track_deployment_metrics(
        deployment=router.get_deployment_by_model_group_name(
            model_group_name="gpt-3.5-turbo"
        ),
        response=ModelResponse(
            model="gpt-3.5-turbo",
            usage={"total_tokens": 100},
        ),
    )


@pytest.mark.parametrize(
    "exception_type, exception_name, num_retries",
    [
        (litellm.exceptions.BadRequestError, "BadRequestError", 3),
        (litellm.exceptions.AuthenticationError, "AuthenticationError", 4),
        (litellm.exceptions.RateLimitError, "RateLimitError", 6),
        (
            litellm.exceptions.ContentPolicyViolationError,
            "ContentPolicyViolationError",
            7,
        ),
    ],
)
def test_get_num_retries_from_retry_policy(
    model_list, exception_type, exception_name, num_retries
):
    """Test if the 'get_num_retries_from_retry_policy' function is working correctly"""
    from litellm.router import RetryPolicy

    data = {exception_name + "Retries": num_retries}
    print("data", data)
    router = Router(
        model_list=model_list,
        retry_policy=RetryPolicy(**data),
    )
    print("exception_type", exception_type)
    calc_num_retries = router.get_num_retries_from_retry_policy(
        exception=exception_type(
            message="test", llm_provider="openai", model="gpt-3.5-turbo"
        )
    )
    assert calc_num_retries == num_retries


@pytest.mark.parametrize(
    "exception_type, exception_name, allowed_fails",
    [
        (litellm.exceptions.BadRequestError, "BadRequestError", 3),
        (litellm.exceptions.AuthenticationError, "AuthenticationError", 4),
        (litellm.exceptions.RateLimitError, "RateLimitError", 6),
        (
            litellm.exceptions.ContentPolicyViolationError,
            "ContentPolicyViolationError",
            7,
        ),
    ],
)
def test_get_allowed_fails_from_policy(
    model_list, exception_type, exception_name, allowed_fails
):
    """Test if the 'get_allowed_fails_from_policy' function is working correctly"""
    from litellm.types.router import AllowedFailsPolicy

    data = {exception_name + "AllowedFails": allowed_fails}
    router = Router(
        model_list=model_list, allowed_fails_policy=AllowedFailsPolicy(**data)
    )
    calc_allowed_fails = router.get_allowed_fails_from_policy(
        exception=exception_type(
            message="test", llm_provider="openai", model="gpt-3.5-turbo"
        )
    )
    assert calc_allowed_fails == allowed_fails


def test_initialize_alerting(model_list):
    """Test if the 'initialize_alerting' function is working correctly"""
    from litellm.types.router import AlertingConfig
    from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting

    router = Router(
        model_list=model_list, alerting_config=AlertingConfig(webhook_url="test")
    )
    router._initialize_alerting()

    callback_added = False
    for callback in litellm.callbacks:
        if isinstance(callback, SlackAlerting):
            callback_added = True
    assert callback_added is True


def test_flush_cache(model_list):
    """Test if the 'flush_cache' function is working correctly"""
    router = Router(model_list=model_list)
    router.cache.set_cache("test", "test")
    assert router.cache.get_cache("test") == "test"
    router.flush_cache()
    assert router.cache.get_cache("test") is None
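The retry-policy and allowed-fails tests above build policy objects keyed by exception class name (for example `RateLimitErrorRetries`). A simplified standalone sketch of that lookup pattern; `SimpleRetryPolicy` is a stand-in, not litellm's `RetryPolicy`.

# Simplified stand-in for the policy lookup exercised by the parametrized tests above.
class SimpleRetryPolicy:
    def __init__(self, **fields: int):
        self.fields = fields

    def retries_for(self, exception: Exception) -> int:
        # fields are named "<ExceptionClassName>Retries", e.g. "RateLimitErrorRetries"
        return self.fields.get(type(exception).__name__ + "Retries", 0)


class RateLimitError(Exception):
    pass


policy = SimpleRetryPolicy(RateLimitErrorRetries=6, BadRequestErrorRetries=3)
assert policy.retries_for(RateLimitError()) == 6
assert policy.retries_for(ValueError()) == 0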