From 8530000b44241f949580e1f38a49088627ea2c88 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 16 Oct 2024 20:02:27 +0530 Subject: [PATCH] (testing) Router add testing coverage (#6253) * test: add more router code coverage * test: additional router testing coverage * fix: fix linting error * test: fix tests for ci/cd * test: fix test * test: handle flaky tests --------- Co-authored-by: Krrish Dholakia --- litellm/router.py | 200 +++--- litellm/router_utils/cooldown_handlers.py | 10 +- .../track_deployment_metrics.py | 3 +- .../router_code_coverage.py | 13 +- tests/local_testing/test_amazing_s3_logs.py | 1 + tests/local_testing/test_completion.py | 1 + .../test_router_helper_utils.py | 584 ++++++++++++++++++ 7 files changed, 706 insertions(+), 106 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index bbc3db21d..cec462846 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -861,11 +861,23 @@ class Router: self.fail_calls[model_name] += 1 raise e + def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None: + """ + Adds default litellm params to kwargs, if set. + """ + for k, v in self.default_litellm_params.items(): + if ( + k not in kwargs and v is not None + ): # prioritize model-specific params > default router params + kwargs[k] = v + elif k == "metadata": + kwargs[k].update(v) + def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None: """ - Adds selected deployment, model_info and api_base to kwargs["metadata"] - - This is used in litellm logging callbacks + 2 jobs: + - Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging) + - Adds default litellm params to kwargs, if set. """ kwargs.setdefault("metadata", {}).update( { @@ -875,13 +887,7 @@ class Router: } ) kwargs["model_info"] = deployment.get("model_info", {}) - for k, v in self.default_litellm_params.items(): - if ( - k not in kwargs and v is not None - ): # prioritize model-specific params > default router params - kwargs[k] = v - elif k == "metadata": - kwargs[k].update(v) + self._update_kwargs_with_default_litellm_params(kwargs=kwargs) def _get_async_openai_model_client(self, deployment: dict, kwargs: dict): """ @@ -910,6 +916,7 @@ class Router: return model_client def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]: + """Helper to get timeout from kwargs or deployment params""" timeout = ( data.get( "timeout", None @@ -3414,11 +3421,10 @@ class Router: ): """ Track remaining tpm/rpm quota for model in model_list + + Currently, only updates TPM usage. 
""" try: - """ - Update TPM usage on success - """ if kwargs["litellm_params"].get("metadata") is None: pass else: @@ -3459,6 +3465,8 @@ class Router: deployment_id=id, ) + return tpm_key + except Exception as e: verbose_router_logger.exception( "litellm.router.Router::deployment_callback_on_success(): Exception occured - {}".format( @@ -3473,7 +3481,14 @@ class Router: completion_response, # response from completion start_time, end_time, # start/end time - ): + ) -> Optional[str]: + """ + Tracks the number of successes for a deployment in the current minute (using in-memory cache) + + Returns: + - key: str - The key used to increment the cache + - None: if no key is found + """ id = None if kwargs["litellm_params"].get("metadata") is None: pass @@ -3482,15 +3497,18 @@ class Router: model_info = kwargs["litellm_params"].get("model_info", {}) or {} id = model_info.get("id", None) if model_group is None or id is None: - return + return None elif isinstance(id, int): id = str(id) if id is not None: - increment_deployment_successes_for_current_minute( + key = increment_deployment_successes_for_current_minute( litellm_router_instance=self, deployment_id=id, ) + return key + + return None def deployment_callback_on_failure( self, @@ -3498,15 +3516,19 @@ class Router: completion_response, # response from completion start_time, end_time, # start/end time - ): + ) -> bool: + """ + 2 jobs: + - Tracks the number of failures for a deployment in the current minute (using in-memory cache) + - Puts the deployment in cooldown if it exceeds the allowed fails / minute + + Returns: + - True if the deployment should be put in cooldown + - False if the deployment should not be put in cooldown + """ try: exception = kwargs.get("exception", None) exception_status = getattr(exception, "status_code", "") - model_name = kwargs.get("model", None) # i.e. gpt35turbo - custom_llm_provider = kwargs.get("litellm_params", {}).get( - "custom_llm_provider", None - ) # i.e. 
azure - kwargs.get("litellm_params", {}).get("metadata", None) _model_info = kwargs.get("litellm_params", {}).get("model_info", {}) exception_headers = litellm.litellm_core_utils.exception_mapping_utils._get_response_headers( @@ -3535,15 +3557,17 @@ class Router: litellm_router_instance=self, deployment_id=deployment_id, ) - _set_cooldown_deployments( + result = _set_cooldown_deployments( litellm_router_instance=self, exception_status=exception_status, original_exception=exception, deployment=deployment_id, time_to_cooldown=_time_to_cooldown, ) # setting deployment_id in cooldown deployments - if custom_llm_provider: - model_name = f"{custom_llm_provider}/{model_name}" + + return result + else: + return False except Exception as e: raise e @@ -3582,9 +3606,12 @@ class Router: except Exception as e: raise e - def _update_usage(self, deployment_id: str): + def _update_usage(self, deployment_id: str) -> int: """ Update deployment rpm for that minute + + Returns: + - int: request count """ rpm_key = deployment_id @@ -3600,6 +3627,8 @@ class Router: key=rpm_key, value=request_count, local_only=True ) # don't change existing ttl + return request_count + def _is_cooldown_required( self, model_id: str, @@ -3778,7 +3807,7 @@ class Router: for _callback in litellm.callbacks: if isinstance(_callback, CustomLogger): try: - _ = await _callback.async_pre_call_check(deployment) + await _callback.async_pre_call_check(deployment) except litellm.RateLimitError as e: ## LOG FAILURE EVENT if logging_obj is not None: @@ -3848,10 +3877,23 @@ class Router: return hash_object.hexdigest() def _create_deployment( - self, model: dict, _model_name: str, _litellm_params: dict, _model_info: dict - ): + self, + deployment_info: dict, + _model_name: str, + _litellm_params: dict, + _model_info: dict, + ) -> Optional[Deployment]: + """ + Create a deployment object and add it to the model list + + If the deployment is not active for the current environment, it is ignored + + Returns: + - Deployment: The deployment object + - None: If the deployment is not active for the current environment (if 'supported_environments' is set in litellm_params) + """ deployment = Deployment( - **model, + **deployment_info, model_name=_model_name, litellm_params=LiteLLM_Params(**_litellm_params), model_info=_model_info, @@ -3870,18 +3912,18 @@ class Router: ) ## Check if LLM Deployment is allowed for this deployment - if deployment.model_info and "supported_environments" in deployment.model_info: - if ( - self.deployment_is_active_for_environment(deployment=deployment) - is not True - ): - return + if self.deployment_is_active_for_environment(deployment=deployment) is not True: + verbose_router_logger.warning( + f"Ignoring deployment {deployment.model_name} as it is not active for environment {deployment.model_info['supported_environments']}" + ) + return None deployment = self._add_deployment(deployment=deployment) model = deployment.to_json(exclude_none=True) self.model_list.append(model) + return deployment def deployment_is_active_for_environment(self, deployment: Deployment) -> bool: """ @@ -3896,6 +3938,12 @@ class Router: - ValueError: If LITELLM_ENVIRONMENT is not set in .env or not one of the valid values - ValueError: If supported_environments is not set in model_info or not one of the valid values """ + if ( + deployment.model_info is None + or "supported_environments" not in deployment.model_info + or deployment.model_info["supported_environments"] is None + ): + return True litellm_environment = 
get_secret_str(secret_name="LITELLM_ENVIRONMENT") if litellm_environment is None: raise ValueError( @@ -3913,7 +3961,6 @@ class Router: f"supported_environments must be one of {VALID_LITELLM_ENVIRONMENTS}. but set as: {_env} for deployment: {deployment}" ) - # validate litellm_environment is one of LiteLLMEnvironment if litellm_environment in deployment.model_info["supported_environments"]: return True return False @@ -3946,14 +3993,14 @@ class Router: for org in _litellm_params["organization"]: _litellm_params["organization"] = org self._create_deployment( - model=model, + deployment_info=model, _model_name=_model_name, _litellm_params=_litellm_params, _model_info=_model_info, ) else: self._create_deployment( - model=model, + deployment_info=model, _model_name=_model_name, _litellm_params=_litellm_params, _model_info=_model_info, @@ -4118,9 +4165,9 @@ class Router: if removal_idx is not None: self.model_list.pop(removal_idx) - else: - # if the model_id is not in router - self.add_deployment(deployment=deployment) + + # if the model_id is not in router + self.add_deployment(deployment=deployment) return deployment def delete_deployment(self, id: str) -> Optional[Deployment]: @@ -4628,16 +4675,13 @@ class Router: from collections import defaultdict access_groups = defaultdict(list) - if self.access_groups: - return self.access_groups if self.model_list: for m in self.model_list: for group in m.get("model_info", {}).get("access_groups", []): model_name = m["model_name"] access_groups[group].append(model_name) - # set access groups - self.access_groups = access_groups + return access_groups def get_settings(self): @@ -4672,6 +4716,9 @@ class Router: return _settings_to_return def update_settings(self, **kwargs): + """ + Update the router settings. + """ # only the following settings are allowed to be configured _allowed_settings = [ "routing_strategy_args", @@ -5367,66 +5414,16 @@ class Router: return healthy_deployments def _track_deployment_metrics(self, deployment, response=None): + """ + Tracks successful requests rpm usage. 
+ """ try: - litellm_params = deployment["litellm_params"] - api_base = litellm_params.get("api_base", "") - model = litellm_params.get("model", "") - model_id = deployment.get("model_info", {}).get("id", None) if response is None: # update self.deployment_stats if model_id is not None: self._update_usage(model_id) # update in-memory cache for tracking - if model_id in self.deployment_stats: - # only update num_requests - self.deployment_stats[model_id]["num_requests"] += 1 - else: - self.deployment_stats[model_id] = { - "api_base": api_base, - "model": model, - "num_requests": 1, - } - else: - # check response_ms and update num_successes - if isinstance(response, dict): - response_ms = response.get("_response_ms", 0) - else: - response_ms = 0 - if model_id is not None: - if model_id in self.deployment_stats: - # check if avg_latency exists - if "avg_latency" in self.deployment_stats[model_id]: - # update avg_latency - self.deployment_stats[model_id]["avg_latency"] = ( - self.deployment_stats[model_id]["avg_latency"] - + response_ms - ) / self.deployment_stats[model_id]["num_successes"] - else: - self.deployment_stats[model_id]["avg_latency"] = response_ms - - # check if num_successes exists - if "num_successes" in self.deployment_stats[model_id]: - self.deployment_stats[model_id]["num_successes"] += 1 - else: - self.deployment_stats[model_id]["num_successes"] = 1 - else: - self.deployment_stats[model_id] = { - "api_base": api_base, - "model": model, - "num_successes": 1, - "avg_latency": response_ms, - } - if self.set_verbose is True and self.debug_level == "DEBUG": - from pprint import pformat - - # Assuming self.deployment_stats is your dictionary - formatted_stats = pformat(self.deployment_stats) - - # Assuming verbose_router_logger is your logger - verbose_router_logger.info( - "self.deployment_stats: \n%s", formatted_stats - ) except Exception as e: verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}") @@ -5442,6 +5439,7 @@ class Router: """ # if we can find the exception then in the retry policy -> return the number of retries retry_policy: Optional[RetryPolicy] = self.retry_policy + if ( self.model_group_retry_policy is not None and model_group is not None @@ -5540,7 +5538,9 @@ class Router: litellm.success_callback.append( _slack_alerting_logger.response_taking_too_long_callback ) - print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n") # noqa + verbose_router_logger.info( + "\033[94m\nInitialized Alerting for litellm.Router\033[0m\n" + ) def set_custom_routing_strategy( self, CustomRoutingStrategy: CustomRoutingStrategyBase diff --git a/litellm/router_utils/cooldown_handlers.py b/litellm/router_utils/cooldown_handlers.py index e9212e0e4..5d16950bd 100644 --- a/litellm/router_utils/cooldown_handlers.py +++ b/litellm/router_utils/cooldown_handlers.py @@ -148,13 +148,17 @@ def _set_cooldown_deployments( exception_status: Union[str, int], deployment: Optional[str] = None, time_to_cooldown: Optional[float] = None, -): +) -> bool: """ Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute or the exception is not one that should be immediately retried (e.g. 
401) + + Returns: + - True if the deployment should be put in cooldown + - False if the deployment should not be put in cooldown """ if ( _should_run_cooldown_logic( @@ -163,7 +167,7 @@ def _set_cooldown_deployments( is False or deployment is None ): - return + return False exception_status_int = cast_exception_status_to_int(exception_status) @@ -191,6 +195,8 @@ def _set_cooldown_deployments( cooldown_time=cooldown_time, ) ) + return True + return False async def _async_get_cooldown_deployments( diff --git a/litellm/router_utils/router_callbacks/track_deployment_metrics.py b/litellm/router_utils/router_callbacks/track_deployment_metrics.py index c09c25543..5d4440222 100644 --- a/litellm/router_utils/router_callbacks/track_deployment_metrics.py +++ b/litellm/router_utils/router_callbacks/track_deployment_metrics.py @@ -24,7 +24,7 @@ else: def increment_deployment_successes_for_current_minute( litellm_router_instance: LitellmRouter, deployment_id: str, -): +) -> str: """ In-Memory: Increments the number of successes for the current minute for a deployment_id """ @@ -35,6 +35,7 @@ def increment_deployment_successes_for_current_minute( value=1, ttl=60, ) + return key def increment_deployment_failures_for_current_minute( diff --git a/tests/code_coverage_tests/router_code_coverage.py b/tests/code_coverage_tests/router_code_coverage.py index fb88c3504..946c30220 100644 --- a/tests/code_coverage_tests/router_code_coverage.py +++ b/tests/code_coverage_tests/router_code_coverage.py @@ -11,9 +11,15 @@ def get_function_names_from_file(file_path): function_names = [] - for node in ast.walk(tree): + for node in tree.body: if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + # Top-level functions function_names.append(node.name) + elif isinstance(node, ast.ClassDef): + # Functions inside classes + for class_node in node.body: + if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)): + function_names.append(class_node.name) return function_names @@ -79,6 +85,7 @@ ignored_function_names = [ "a_add_message", "aget_messages", "arun_thread", + "try_retrieve_batch", ] @@ -103,8 +110,8 @@ def main(): if func not in ignored_function_names: all_untested_functions.append(func) untested_perc = (len(all_untested_functions)) / len(router_functions) - print("perc_covered: ", untested_perc) - if untested_perc < 0.3: + print("untested_perc: ", untested_perc) + if untested_perc > 0: print("The following functions in router.py are not tested:") raise Exception( f"{untested_perc * 100:.2f}% of functions in router.py are not tested: {all_untested_functions}" diff --git a/tests/local_testing/test_amazing_s3_logs.py b/tests/local_testing/test_amazing_s3_logs.py index 5459647c1..f489a5a0e 100644 --- a/tests/local_testing/test_amazing_s3_logs.py +++ b/tests/local_testing/test_amazing_s3_logs.py @@ -20,6 +20,7 @@ import boto3 @pytest.mark.asyncio @pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.flaky(retries=6, delay=1) async def test_basic_s3_logging(sync_mode): verbose_logger.setLevel(level=logging.DEBUG) litellm.success_callback = ["s3"] diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 33c0b67f1..4dc9cc91c 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -3789,6 +3789,7 @@ def test_completion_anyscale_api(): # @pytest.mark.skip(reason="flaky test, times out frequently") +@pytest.mark.flaky(retries=6, delay=1) def test_completion_cohere(): try: # litellm.set_verbose=True diff --git 
a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py index 97660471b..f34eb428f 100644 --- a/tests/router_unit_tests/test_router_helper_utils.py +++ b/tests/router_unit_tests/test_router_helper_utils.py @@ -10,6 +10,7 @@ sys.path.insert( ) # Adds the parent directory to the system path from litellm import Router import pytest +import litellm from unittest.mock import patch, MagicMock, AsyncMock @@ -22,6 +23,9 @@ def model_list(): "model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY"), }, + "model_info": { + "access_groups": ["group1", "group2"], + }, }, { "model_name": "gpt-4o", @@ -250,3 +254,583 @@ async def test_router_make_call(model_list): mock_response="https://example.com/image.png", ) assert response.data[0].url == "https://example.com/image.png" + + +def test_update_kwargs_with_deployment(model_list): + """Test if the '_update_kwargs_with_deployment' function is working correctly""" + router = Router(model_list=model_list) + kwargs: dict = {"metadata": {}} + deployment = router.get_deployment_by_model_group_name( + model_group_name="gpt-3.5-turbo" + ) + router._update_kwargs_with_deployment( + deployment=deployment, + kwargs=kwargs, + ) + set_fields = ["deployment", "api_base", "model_info"] + assert all(field in kwargs["metadata"] for field in set_fields) + + +def test_update_kwargs_with_default_litellm_params(model_list): + """Test if the '_update_kwargs_with_default_litellm_params' function is working correctly""" + router = Router( + model_list=model_list, + default_litellm_params={"api_key": "test", "metadata": {"key": "value"}}, + ) + kwargs: dict = {"metadata": {"key2": "value2"}} + router._update_kwargs_with_default_litellm_params(kwargs=kwargs) + assert kwargs["api_key"] == "test" + assert kwargs["metadata"]["key"] == "value" + assert kwargs["metadata"]["key2"] == "value2" + + +def test_get_async_openai_model_client(model_list): + """Test if the '_get_async_openai_model_client' function is working correctly""" + router = Router(model_list=model_list) + deployment = router.get_deployment_by_model_group_name( + model_group_name="gpt-3.5-turbo" + ) + model_client = router._get_async_openai_model_client( + deployment=deployment, kwargs={} + ) + assert model_client is not None + + +def test_get_timeout(model_list): + """Test if the '_get_timeout' function is working correctly""" + router = Router(model_list=model_list) + timeout = router._get_timeout(kwargs={}, data={"timeout": 100}) + assert timeout == 100 + + +@pytest.mark.parametrize( + "fallback_kwarg, expected_error", + [ + ("mock_testing_fallbacks", litellm.InternalServerError), + ("mock_testing_context_fallbacks", litellm.ContextWindowExceededError), + ("mock_testing_content_policy_fallbacks", litellm.ContentPolicyViolationError), + ], +) +def test_handle_mock_testing_fallbacks(model_list, fallback_kwarg, expected_error): + """Test if the '_handle_mock_testing_fallbacks' function is working correctly""" + router = Router(model_list=model_list) + with pytest.raises(expected_error): + data = { + fallback_kwarg: True, + } + router._handle_mock_testing_fallbacks( + kwargs=data, + ) + + +def test_handle_mock_testing_rate_limit_error(model_list): + """Test if the '_handle_mock_testing_rate_limit_error' function is working correctly""" + router = Router(model_list=model_list) + with pytest.raises(litellm.RateLimitError): + data = { + "mock_testing_rate_limit_error": True, + } + router._handle_mock_testing_rate_limit_error( + kwargs=data, + ) + + +def 
test_get_fallback_model_group_from_fallbacks(model_list): + """Test if the '_get_fallback_model_group_from_fallbacks' function is working correctly""" + router = Router(model_list=model_list) + fallback_model_group_name = router._get_fallback_model_group_from_fallbacks( + model_group="gpt-4o", + fallbacks=[{"gpt-4o": "gpt-3.5-turbo"}], + ) + assert fallback_model_group_name == "gpt-3.5-turbo" + + +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_deployment_callback_on_success(model_list, sync_mode): + """Test if the '_deployment_callback_on_success' function is working correctly""" + import time + + router = Router(model_list=model_list) + + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "gpt-3.5-turbo", + }, + "model_info": {"id": 100}, + }, + } + response = litellm.ModelResponse( + model="gpt-3.5-turbo", + usage={"total_tokens": 100}, + ) + if sync_mode: + tpm_key = router.sync_deployment_callback_on_success( + kwargs=kwargs, + completion_response=response, + start_time=time.time(), + end_time=time.time(), + ) + else: + tpm_key = await router.deployment_callback_on_success( + kwargs=kwargs, + completion_response=response, + start_time=time.time(), + end_time=time.time(), + ) + assert tpm_key is not None + + +def test_deployment_callback_on_failure(model_list): + """Test if the '_deployment_callback_on_failure' function is working correctly""" + import time + + router = Router(model_list=model_list) + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "gpt-3.5-turbo", + }, + "model_info": {"id": 100}, + }, + } + result = router.deployment_callback_on_failure( + kwargs=kwargs, + completion_response=None, + start_time=time.time(), + end_time=time.time(), + ) + assert isinstance(result, bool) + assert result is False + + +def test_log_retry(model_list): + """Test if the '_log_retry' function is working correctly""" + import time + + router = Router(model_list=model_list) + new_kwargs = router.log_retry( + kwargs={"metadata": {}}, + e=Exception(), + ) + assert "metadata" in new_kwargs + assert "previous_models" in new_kwargs["metadata"] + + +def test_update_usage(model_list): + """Test if the '_update_usage' function is working correctly""" + router = Router(model_list=model_list) + deployment = router.get_deployment_by_model_group_name( + model_group_name="gpt-3.5-turbo" + ) + deployment_id = deployment["model_info"]["id"] + request_count = router._update_usage( + deployment_id=deployment_id, + ) + assert request_count == 1 + + request_count = router._update_usage( + deployment_id=deployment_id, + ) + + assert request_count == 2 + + +@pytest.mark.parametrize( + "finish_reason, expected_error", [("content_filter", True), ("stop", False)] +) +def test_should_raise_content_policy_error(model_list, finish_reason, expected_error): + """Test if the '_should_raise_content_policy_error' function is working correctly""" + router = Router(model_list=model_list) + + assert ( + router._should_raise_content_policy_error( + model="gpt-3.5-turbo", + response=litellm.ModelResponse( + model="gpt-3.5-turbo", + choices=[ + { + "finish_reason": finish_reason, + "message": {"content": "I'm fine, thank you!"}, + } + ], + usage={"total_tokens": 100}, + ), + kwargs={ + "content_policy_fallbacks": [{"gpt-3.5-turbo": "gpt-4o"}], + }, + ) + is expected_error + ) + + +def test_get_healthy_deployments(model_list): + """Test if the '_get_healthy_deployments' function is working correctly""" + router = Router(model_list=model_list) + deployments = 
router._get_healthy_deployments(model="gpt-3.5-turbo") + assert len(deployments) > 0 + + +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_routing_strategy_pre_call_checks(model_list, sync_mode): + """Test if the '_routing_strategy_pre_call_checks' function is working correctly""" + from litellm.integrations.custom_logger import CustomLogger + from litellm.litellm_core_utils.litellm_logging import Logging + + callback = CustomLogger() + litellm.callbacks = [callback] + + router = Router(model_list=model_list) + + deployment = router.get_deployment_by_model_group_name( + model_group_name="gpt-3.5-turbo" + ) + + litellm_logging_obj = Logging( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + stream=False, + call_type="acompletion", + litellm_call_id="1234", + start_time=datetime.now(), + function_id="1234", + ) + if sync_mode: + router.routing_strategy_pre_call_checks(deployment) + else: + ## NO EXCEPTION + await router.async_routing_strategy_pre_call_checks( + deployment, litellm_logging_obj + ) + + ## WITH EXCEPTION - rate limit error + with patch.object( + callback, + "async_pre_call_check", + AsyncMock( + side_effect=litellm.RateLimitError( + message="Rate limit error", + llm_provider="openai", + model="gpt-3.5-turbo", + ) + ), + ): + try: + await router.async_routing_strategy_pre_call_checks( + deployment, litellm_logging_obj + ) + pytest.fail("Exception was not raised") + except Exception as e: + assert isinstance(e, litellm.RateLimitError) + + ## WITH EXCEPTION - generic error + with patch.object( + callback, "async_pre_call_check", AsyncMock(side_effect=Exception("Error")) + ): + try: + await router.async_routing_strategy_pre_call_checks( + deployment, litellm_logging_obj + ) + pytest.fail("Exception was not raised") + except Exception as e: + assert isinstance(e, Exception) + + +@pytest.mark.parametrize( + "set_supported_environments, supported_environments, is_supported", + [(True, ["staging"], True), (False, None, True), (True, ["development"], False)], +) +def test_create_deployment( + model_list, set_supported_environments, supported_environments, is_supported +): + """Test if the '_create_deployment' function is working correctly""" + router = Router(model_list=model_list) + + if set_supported_environments: + os.environ["LITELLM_ENVIRONMENT"] = "staging" + deployment = router._create_deployment( + deployment_info={}, + _model_name="gpt-3.5-turbo", + _litellm_params={ + "model": "gpt-3.5-turbo", + "api_key": "test", + "custom_llm_provider": "openai", + }, + _model_info={ + "id": 100, + "supported_environments": supported_environments, + }, + ) + if is_supported: + assert deployment is not None + else: + assert deployment is None + + +@pytest.mark.parametrize( + "set_supported_environments, supported_environments, is_supported", + [(True, ["staging"], True), (False, None, True), (True, ["development"], False)], +) +def test_deployment_is_active_for_environment( + model_list, set_supported_environments, supported_environments, is_supported +): + """Test if the '_deployment_is_active_for_environment' function is working correctly""" + router = Router(model_list=model_list) + deployment = router.get_deployment_by_model_group_name( + model_group_name="gpt-3.5-turbo" + ) + if set_supported_environments: + os.environ["LITELLM_ENVIRONMENT"] = "staging" + deployment["model_info"]["supported_environments"] = supported_environments + if is_supported: + assert ( + 
router.deployment_is_active_for_environment(deployment=deployment) is True + ) + else: + assert ( + router.deployment_is_active_for_environment(deployment=deployment) is False + ) + + +def test_set_model_list(model_list): + """Test if the '_set_model_list' function is working correctly""" + router = Router(model_list=model_list) + router.set_model_list(model_list=model_list) + assert len(router.model_list) == len(model_list) + + +def test_add_deployment(model_list): + """Test if the '_add_deployment' function is working correctly""" + router = Router(model_list=model_list) + deployment = router.get_deployment_by_model_group_name( + model_group_name="gpt-3.5-turbo" + ) + deployment["model_info"]["id"] = 100 + ## Test 1: call user facing function + router.add_deployment(deployment=deployment) + + ## Test 2: call internal function + router._add_deployment(deployment=deployment) + assert len(router.model_list) == len(model_list) + 1 + + +def test_upsert_deployment(model_list): + """Test if the 'upsert_deployment' function is working correctly""" + router = Router(model_list=model_list) + print("model list", len(router.model_list)) + deployment = router.get_deployment_by_model_group_name( + model_group_name="gpt-3.5-turbo" + ) + deployment.litellm_params.model = "gpt-4o" + router.upsert_deployment(deployment=deployment) + assert len(router.model_list) == len(model_list) + + +def test_delete_deployment(model_list): + """Test if the 'delete_deployment' function is working correctly""" + router = Router(model_list=model_list) + deployment = router.get_deployment_by_model_group_name( + model_group_name="gpt-3.5-turbo" + ) + router.delete_deployment(id=deployment["model_info"]["id"]) + assert len(router.model_list) == len(model_list) - 1 + + +def test_get_model_info(model_list): + """Test if the 'get_model_info' function is working correctly""" + router = Router(model_list=model_list) + deployment = router.get_deployment_by_model_group_name( + model_group_name="gpt-3.5-turbo" + ) + model_info = router.get_model_info(id=deployment["model_info"]["id"]) + assert model_info is not None + + +def test_get_model_group(model_list): + """Test if the 'get_model_group' function is working correctly""" + router = Router(model_list=model_list) + deployment = router.get_deployment_by_model_group_name( + model_group_name="gpt-3.5-turbo" + ) + model_group = router.get_model_group(id=deployment["model_info"]["id"]) + assert model_group is not None + assert model_group[0]["model_name"] == "gpt-3.5-turbo" + + +@pytest.mark.parametrize("user_facing_model_group_name", ["gpt-3.5-turbo", "gpt-4o"]) +def test_set_model_group_info(model_list, user_facing_model_group_name): + """Test if the 'set_model_group_info' function is working correctly""" + router = Router(model_list=model_list) + resp = router._set_model_group_info( + model_group="gpt-3.5-turbo", + user_facing_model_group_name=user_facing_model_group_name, + ) + assert resp is not None + assert resp.model_group == user_facing_model_group_name + + +@pytest.mark.asyncio +async def test_set_response_headers(model_list): + """Test if the 'set_response_headers' function is working correctly""" + router = Router(model_list=model_list) + resp = await router.set_response_headers(response=None, model_group=None) + assert resp is None + + +def test_get_all_deployments(model_list): + """Test if the 'get_all_deployments' function is working correctly""" + router = Router(model_list=model_list) + deployments = router._get_all_deployments( + model_name="gpt-3.5-turbo", 
model_alias="gpt-3.5-turbo" + ) + assert len(deployments) > 0 + + +def test_get_model_access_groups(model_list): + """Test if the 'get_model_access_groups' function is working correctly""" + router = Router(model_list=model_list) + access_groups = router.get_model_access_groups() + assert len(access_groups) == 2 + + +def test_update_settings(model_list): + """Test if the 'update_settings' function is working correctly""" + router = Router(model_list=model_list) + pre_update_allowed_fails = router.allowed_fails + router.update_settings(**{"allowed_fails": 20}) + assert router.allowed_fails != pre_update_allowed_fails + assert router.allowed_fails == 20 + + +def test_common_checks_available_deployment(model_list): + """Test if the 'common_checks_available_deployment' function is working correctly""" + router = Router(model_list=model_list) + _, available_deployments = router._common_checks_available_deployment( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + input="hi", + specific_deployment=False, + ) + + assert len(available_deployments) > 0 + + +def test_filter_cooldown_deployments(model_list): + """Test if the 'filter_cooldown_deployments' function is working correctly""" + router = Router(model_list=model_list) + deployments = router._filter_cooldown_deployments( + healthy_deployments=router._get_all_deployments(model_name="gpt-3.5-turbo"), # type: ignore + cooldown_deployments=[], + ) + assert len(deployments) == len( + router._get_all_deployments(model_name="gpt-3.5-turbo") + ) + + +def test_track_deployment_metrics(model_list): + """Test if the 'track_deployment_metrics' function is working correctly""" + from litellm.types.utils import ModelResponse + + router = Router(model_list=model_list) + router._track_deployment_metrics( + deployment=router.get_deployment_by_model_group_name( + model_group_name="gpt-3.5-turbo" + ), + response=ModelResponse( + model="gpt-3.5-turbo", + usage={"total_tokens": 100}, + ), + ) + + +@pytest.mark.parametrize( + "exception_type, exception_name, num_retries", + [ + (litellm.exceptions.BadRequestError, "BadRequestError", 3), + (litellm.exceptions.AuthenticationError, "AuthenticationError", 4), + (litellm.exceptions.RateLimitError, "RateLimitError", 6), + ( + litellm.exceptions.ContentPolicyViolationError, + "ContentPolicyViolationError", + 7, + ), + ], +) +def test_get_num_retries_from_retry_policy( + model_list, exception_type, exception_name, num_retries +): + """Test if the 'get_num_retries_from_retry_policy' function is working correctly""" + from litellm.router import RetryPolicy + + data = {exception_name + "Retries": num_retries} + print("data", data) + router = Router( + model_list=model_list, + retry_policy=RetryPolicy(**data), + ) + print("exception_type", exception_type) + calc_num_retries = router.get_num_retries_from_retry_policy( + exception=exception_type( + message="test", llm_provider="openai", model="gpt-3.5-turbo" + ) + ) + assert calc_num_retries == num_retries + + +@pytest.mark.parametrize( + "exception_type, exception_name, allowed_fails", + [ + (litellm.exceptions.BadRequestError, "BadRequestError", 3), + (litellm.exceptions.AuthenticationError, "AuthenticationError", 4), + (litellm.exceptions.RateLimitError, "RateLimitError", 6), + ( + litellm.exceptions.ContentPolicyViolationError, + "ContentPolicyViolationError", + 7, + ), + ], +) +def test_get_allowed_fails_from_policy( + model_list, exception_type, exception_name, allowed_fails +): + """Test if the 'get_allowed_fails_from_policy' function is working 
correctly""" + from litellm.types.router import AllowedFailsPolicy + + data = {exception_name + "AllowedFails": allowed_fails} + router = Router( + model_list=model_list, allowed_fails_policy=AllowedFailsPolicy(**data) + ) + calc_allowed_fails = router.get_allowed_fails_from_policy( + exception=exception_type( + message="test", llm_provider="openai", model="gpt-3.5-turbo" + ) + ) + assert calc_allowed_fails == allowed_fails + + +def test_initialize_alerting(model_list): + """Test if the 'initialize_alerting' function is working correctly""" + from litellm.types.router import AlertingConfig + from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting + + router = Router( + model_list=model_list, alerting_config=AlertingConfig(webhook_url="test") + ) + router._initialize_alerting() + + callback_added = False + for callback in litellm.callbacks: + if isinstance(callback, SlackAlerting): + callback_added = True + assert callback_added is True + + +def test_flush_cache(model_list): + """Test if the 'flush_cache' function is working correctly""" + router = Router(model_list=model_list) + router.cache.set_cache("test", "test") + assert router.cache.get_cache("test") == "test" + router.flush_cache() + assert router.cache.get_cache("test") is None