(testing) Router add testing coverage (#6253)

* test: add more router code coverage

* test: additional router testing coverage

* fix: fix linting error

* test: fix tests for ci/cd

* test: fix test

* test: handle flaky tests

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
Ishaan Jaff 2024-10-16 20:02:27 +05:30 committed by GitHub
parent 54ebdbf7ce
commit 8530000b44
7 changed files with 706 additions and 106 deletions

@@ -11,9 +11,15 @@ def get_function_names_from_file(file_path):
     function_names = []
-    for node in ast.walk(tree):
+    for node in tree.body:
         if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            # Top-level functions
             function_names.append(node.name)
+        elif isinstance(node, ast.ClassDef):
+            # Functions inside classes
+            for class_node in node.body:
+                if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+                    function_names.append(class_node.name)
     return function_names
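
One plausible reason for switching from ast.walk to tree.body (stated here as an inference, not the authors' documented intent): ast.walk yields every node in the tree, so nested helper functions would also be counted, while iterating tree.body plus class bodies counts only module-level functions and class methods. A minimal standalone sketch, separate from the repo's script; the example source and names are illustrative only:

import ast

# Illustrative source: a top-level function with a nested helper, plus a class method.
source = '''
def top_level():
    def inner():
        pass

class Router:
    def acompletion(self):
        pass
'''

tree = ast.parse(source)

# ast.walk-style traversal: nested function definitions are included.
walk_names = [
    node.name
    for node in ast.walk(tree)
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
]

# tree.body-style traversal: only module-level functions and class methods.
body_names = []
for node in tree.body:
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        body_names.append(node.name)
    elif isinstance(node, ast.ClassDef):
        for class_node in node.body:
            if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                body_names.append(class_node.name)

print(sorted(walk_names))  # ['acompletion', 'inner', 'top_level']
print(sorted(body_names))  # ['acompletion', 'top_level']
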
@@ -79,6 +85,7 @@ ignored_function_names = [
     "a_add_message",
     "aget_messages",
     "arun_thread",
+    "try_retrieve_batch",
 ]
@@ -103,8 +110,8 @@ def main():
         if func not in ignored_function_names:
            all_untested_functions.append(func)
     untested_perc = (len(all_untested_functions)) / len(router_functions)
-    print("perc_covered: ", untested_perc)
-    if untested_perc < 0.3:
+    print("untested_perc: ", untested_perc)
+    if untested_perc > 0:
         print("The following functions in router.py are not tested:")
         raise Exception(
             f"{untested_perc * 100:.2f}% of functions in router.py are not tested: {all_untested_functions}"

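For context, a condensed, hypothetical standalone version of the updated gate (the committed check lives inside main() in the coverage script): the job now fails if any function in router.py is neither exercised by the tests nor listed in ignored_function_names.

def enforce_router_coverage(router_functions, tested_functions, ignored_function_names):
    # Hypothetical helper name; mirrors the logic of the updated check.
    all_untested_functions = [
        func
        for func in router_functions
        if func not in tested_functions and func not in ignored_function_names
    ]
    untested_perc = len(all_untested_functions) / len(router_functions)
    print("untested_perc: ", untested_perc)
    if untested_perc > 0:
        print("The following functions in router.py are not tested:")
        raise Exception(
            f"{untested_perc * 100:.2f}% of functions in router.py are not tested: {all_untested_functions}"
        )
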
@@ -20,6 +20,7 @@ import boto3
 @pytest.mark.asyncio
 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.flaky(retries=6, delay=1)
 async def test_basic_s3_logging(sync_mode):
     verbose_logger.setLevel(level=logging.DEBUG)
     litellm.success_callback = ["s3"]

@@ -3789,6 +3789,7 @@ def test_completion_anyscale_api():
 # @pytest.mark.skip(reason="flaky test, times out frequently")
+@pytest.mark.flaky(retries=6, delay=1)
 def test_completion_cohere():
     try:
         # litellm.set_verbose=True

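The two hunks above add the same retry marker. A minimal usage sketch, assuming a retry plugin such as pytest-retry is installed (the plugin, not pytest itself, interprets the retries/delay arguments; the test below is illustrative only):

import random

import pytest

@pytest.mark.flaky(retries=6, delay=1)
def test_transiently_failing_dependency():
    # Simulates a dependency that fails intermittently. With the marker, a
    # failing attempt is rerun (up to 6 retries, 1 second apart) before the
    # test is reported as failed.
    assert random.random() < 0.5
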
@@ -10,6 +10,7 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 from litellm import Router
 import pytest
+import litellm
 from unittest.mock import patch, MagicMock, AsyncMock
@@ -22,6 +23,9 @@ def model_list():
                 "model": "gpt-3.5-turbo",
                 "api_key": os.getenv("OPENAI_API_KEY"),
             },
+            "model_info": {
+                "access_groups": ["group1", "group2"],
+            },
         },
         {
             "model_name": "gpt-4o",
@@ -250,3 +254,583 @@ async def test_router_make_call(model_list):
mock_response="https://example.com/image.png",
)
assert response.data[0].url == "https://example.com/image.png"
def test_update_kwargs_with_deployment(model_list):
"""Test if the '_update_kwargs_with_deployment' function is working correctly"""
router = Router(model_list=model_list)
kwargs: dict = {"metadata": {}}
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
router._update_kwargs_with_deployment(
deployment=deployment,
kwargs=kwargs,
)
set_fields = ["deployment", "api_base", "model_info"]
assert all(field in kwargs["metadata"] for field in set_fields)
def test_update_kwargs_with_default_litellm_params(model_list):
"""Test if the '_update_kwargs_with_default_litellm_params' function is working correctly"""
router = Router(
model_list=model_list,
default_litellm_params={"api_key": "test", "metadata": {"key": "value"}},
)
kwargs: dict = {"metadata": {"key2": "value2"}}
router._update_kwargs_with_default_litellm_params(kwargs=kwargs)
assert kwargs["api_key"] == "test"
assert kwargs["metadata"]["key"] == "value"
assert kwargs["metadata"]["key2"] == "value2"
def test_get_async_openai_model_client(model_list):
"""Test if the '_get_async_openai_model_client' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
model_client = router._get_async_openai_model_client(
deployment=deployment, kwargs={}
)
assert model_client is not None
def test_get_timeout(model_list):
"""Test if the '_get_timeout' function is working correctly"""
router = Router(model_list=model_list)
timeout = router._get_timeout(kwargs={}, data={"timeout": 100})
assert timeout == 100
@pytest.mark.parametrize(
"fallback_kwarg, expected_error",
[
("mock_testing_fallbacks", litellm.InternalServerError),
("mock_testing_context_fallbacks", litellm.ContextWindowExceededError),
("mock_testing_content_policy_fallbacks", litellm.ContentPolicyViolationError),
],
)
def test_handle_mock_testing_fallbacks(model_list, fallback_kwarg, expected_error):
"""Test if the '_handle_mock_testing_fallbacks' function is working correctly"""
router = Router(model_list=model_list)
with pytest.raises(expected_error):
data = {
fallback_kwarg: True,
}
router._handle_mock_testing_fallbacks(
kwargs=data,
)
def test_handle_mock_testing_rate_limit_error(model_list):
"""Test if the '_handle_mock_testing_rate_limit_error' function is working correctly"""
router = Router(model_list=model_list)
with pytest.raises(litellm.RateLimitError):
data = {
"mock_testing_rate_limit_error": True,
}
router._handle_mock_testing_rate_limit_error(
kwargs=data,
)
def test_get_fallback_model_group_from_fallbacks(model_list):
"""Test if the '_get_fallback_model_group_from_fallbacks' function is working correctly"""
router = Router(model_list=model_list)
fallback_model_group_name = router._get_fallback_model_group_from_fallbacks(
model_group="gpt-4o",
fallbacks=[{"gpt-4o": "gpt-3.5-turbo"}],
)
assert fallback_model_group_name == "gpt-3.5-turbo"
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_deployment_callback_on_success(model_list, sync_mode):
"""Test if the '_deployment_callback_on_success' function is working correctly"""
import time
router = Router(model_list=model_list)
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
},
"model_info": {"id": 100},
},
}
response = litellm.ModelResponse(
model="gpt-3.5-turbo",
usage={"total_tokens": 100},
)
if sync_mode:
tpm_key = router.sync_deployment_callback_on_success(
kwargs=kwargs,
completion_response=response,
start_time=time.time(),
end_time=time.time(),
)
else:
tpm_key = await router.deployment_callback_on_success(
kwargs=kwargs,
completion_response=response,
start_time=time.time(),
end_time=time.time(),
)
assert tpm_key is not None
def test_deployment_callback_on_failure(model_list):
"""Test if the '_deployment_callback_on_failure' function is working correctly"""
import time
router = Router(model_list=model_list)
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
},
"model_info": {"id": 100},
},
}
result = router.deployment_callback_on_failure(
kwargs=kwargs,
completion_response=None,
start_time=time.time(),
end_time=time.time(),
)
assert isinstance(result, bool)
assert result is False
def test_log_retry(model_list):
"""Test if the '_log_retry' function is working correctly"""
import time
router = Router(model_list=model_list)
new_kwargs = router.log_retry(
kwargs={"metadata": {}},
e=Exception(),
)
assert "metadata" in new_kwargs
assert "previous_models" in new_kwargs["metadata"]
def test_update_usage(model_list):
"""Test if the '_update_usage' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
deployment_id = deployment["model_info"]["id"]
request_count = router._update_usage(
deployment_id=deployment_id,
)
assert request_count == 1
request_count = router._update_usage(
deployment_id=deployment_id,
)
assert request_count == 2
@pytest.mark.parametrize(
"finish_reason, expected_error", [("content_filter", True), ("stop", False)]
)
def test_should_raise_content_policy_error(model_list, finish_reason, expected_error):
"""Test if the '_should_raise_content_policy_error' function is working correctly"""
router = Router(model_list=model_list)
assert (
router._should_raise_content_policy_error(
model="gpt-3.5-turbo",
response=litellm.ModelResponse(
model="gpt-3.5-turbo",
choices=[
{
"finish_reason": finish_reason,
"message": {"content": "I'm fine, thank you!"},
}
],
usage={"total_tokens": 100},
),
kwargs={
"content_policy_fallbacks": [{"gpt-3.5-turbo": "gpt-4o"}],
},
)
is expected_error
)
def test_get_healthy_deployments(model_list):
"""Test if the '_get_healthy_deployments' function is working correctly"""
router = Router(model_list=model_list)
deployments = router._get_healthy_deployments(model="gpt-3.5-turbo")
assert len(deployments) > 0
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_routing_strategy_pre_call_checks(model_list, sync_mode):
"""Test if the '_routing_strategy_pre_call_checks' function is working correctly"""
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging
callback = CustomLogger()
litellm.callbacks = [callback]
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
litellm_logging_obj = Logging(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
stream=False,
call_type="acompletion",
litellm_call_id="1234",
start_time=datetime.now(),
function_id="1234",
)
if sync_mode:
router.routing_strategy_pre_call_checks(deployment)
else:
## NO EXCEPTION
await router.async_routing_strategy_pre_call_checks(
deployment, litellm_logging_obj
)
## WITH EXCEPTION - rate limit error
with patch.object(
callback,
"async_pre_call_check",
AsyncMock(
side_effect=litellm.RateLimitError(
message="Rate limit error",
llm_provider="openai",
model="gpt-3.5-turbo",
)
),
):
try:
await router.async_routing_strategy_pre_call_checks(
deployment, litellm_logging_obj
)
pytest.fail("Exception was not raised")
except Exception as e:
assert isinstance(e, litellm.RateLimitError)
## WITH EXCEPTION - generic error
with patch.object(
callback, "async_pre_call_check", AsyncMock(side_effect=Exception("Error"))
):
try:
await router.async_routing_strategy_pre_call_checks(
deployment, litellm_logging_obj
)
pytest.fail("Exception was not raised")
except Exception as e:
assert isinstance(e, Exception)
@pytest.mark.parametrize(
"set_supported_environments, supported_environments, is_supported",
[(True, ["staging"], True), (False, None, True), (True, ["development"], False)],
)
def test_create_deployment(
model_list, set_supported_environments, supported_environments, is_supported
):
"""Test if the '_create_deployment' function is working correctly"""
router = Router(model_list=model_list)
if set_supported_environments:
os.environ["LITELLM_ENVIRONMENT"] = "staging"
deployment = router._create_deployment(
deployment_info={},
_model_name="gpt-3.5-turbo",
_litellm_params={
"model": "gpt-3.5-turbo",
"api_key": "test",
"custom_llm_provider": "openai",
},
_model_info={
"id": 100,
"supported_environments": supported_environments,
},
)
if is_supported:
assert deployment is not None
else:
assert deployment is None
@pytest.mark.parametrize(
"set_supported_environments, supported_environments, is_supported",
[(True, ["staging"], True), (False, None, True), (True, ["development"], False)],
)
def test_deployment_is_active_for_environment(
model_list, set_supported_environments, supported_environments, is_supported
):
"""Test if the '_deployment_is_active_for_environment' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
if set_supported_environments:
os.environ["LITELLM_ENVIRONMENT"] = "staging"
deployment["model_info"]["supported_environments"] = supported_environments
if is_supported:
assert (
router.deployment_is_active_for_environment(deployment=deployment) is True
)
else:
assert (
router.deployment_is_active_for_environment(deployment=deployment) is False
)
def test_set_model_list(model_list):
"""Test if the '_set_model_list' function is working correctly"""
router = Router(model_list=model_list)
router.set_model_list(model_list=model_list)
assert len(router.model_list) == len(model_list)
def test_add_deployment(model_list):
"""Test if the '_add_deployment' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
deployment["model_info"]["id"] = 100
## Test 1: call user facing function
router.add_deployment(deployment=deployment)
## Test 2: call internal function
router._add_deployment(deployment=deployment)
assert len(router.model_list) == len(model_list) + 1
def test_upsert_deployment(model_list):
"""Test if the 'upsert_deployment' function is working correctly"""
router = Router(model_list=model_list)
print("model list", len(router.model_list))
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
deployment.litellm_params.model = "gpt-4o"
router.upsert_deployment(deployment=deployment)
assert len(router.model_list) == len(model_list)
def test_delete_deployment(model_list):
"""Test if the 'delete_deployment' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
router.delete_deployment(id=deployment["model_info"]["id"])
assert len(router.model_list) == len(model_list) - 1
def test_get_model_info(model_list):
"""Test if the 'get_model_info' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
model_info = router.get_model_info(id=deployment["model_info"]["id"])
assert model_info is not None
def test_get_model_group(model_list):
"""Test if the 'get_model_group' function is working correctly"""
router = Router(model_list=model_list)
deployment = router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
)
model_group = router.get_model_group(id=deployment["model_info"]["id"])
assert model_group is not None
assert model_group[0]["model_name"] == "gpt-3.5-turbo"
@pytest.mark.parametrize("user_facing_model_group_name", ["gpt-3.5-turbo", "gpt-4o"])
def test_set_model_group_info(model_list, user_facing_model_group_name):
"""Test if the 'set_model_group_info' function is working correctly"""
router = Router(model_list=model_list)
resp = router._set_model_group_info(
model_group="gpt-3.5-turbo",
user_facing_model_group_name=user_facing_model_group_name,
)
assert resp is not None
assert resp.model_group == user_facing_model_group_name
@pytest.mark.asyncio
async def test_set_response_headers(model_list):
"""Test if the 'set_response_headers' function is working correctly"""
router = Router(model_list=model_list)
resp = await router.set_response_headers(response=None, model_group=None)
assert resp is None
def test_get_all_deployments(model_list):
"""Test if the 'get_all_deployments' function is working correctly"""
router = Router(model_list=model_list)
deployments = router._get_all_deployments(
model_name="gpt-3.5-turbo", model_alias="gpt-3.5-turbo"
)
assert len(deployments) > 0
def test_get_model_access_groups(model_list):
"""Test if the 'get_model_access_groups' function is working correctly"""
router = Router(model_list=model_list)
access_groups = router.get_model_access_groups()
assert len(access_groups) == 2
def test_update_settings(model_list):
"""Test if the 'update_settings' function is working correctly"""
router = Router(model_list=model_list)
pre_update_allowed_fails = router.allowed_fails
router.update_settings(**{"allowed_fails": 20})
assert router.allowed_fails != pre_update_allowed_fails
assert router.allowed_fails == 20
def test_common_checks_available_deployment(model_list):
"""Test if the 'common_checks_available_deployment' function is working correctly"""
router = Router(model_list=model_list)
_, available_deployments = router._common_checks_available_deployment(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
input="hi",
specific_deployment=False,
)
assert len(available_deployments) > 0
def test_filter_cooldown_deployments(model_list):
"""Test if the 'filter_cooldown_deployments' function is working correctly"""
router = Router(model_list=model_list)
deployments = router._filter_cooldown_deployments(
healthy_deployments=router._get_all_deployments(model_name="gpt-3.5-turbo"), # type: ignore
cooldown_deployments=[],
)
assert len(deployments) == len(
router._get_all_deployments(model_name="gpt-3.5-turbo")
)
def test_track_deployment_metrics(model_list):
"""Test if the 'track_deployment_metrics' function is working correctly"""
from litellm.types.utils import ModelResponse
router = Router(model_list=model_list)
router._track_deployment_metrics(
deployment=router.get_deployment_by_model_group_name(
model_group_name="gpt-3.5-turbo"
),
response=ModelResponse(
model="gpt-3.5-turbo",
usage={"total_tokens": 100},
),
)
@pytest.mark.parametrize(
"exception_type, exception_name, num_retries",
[
(litellm.exceptions.BadRequestError, "BadRequestError", 3),
(litellm.exceptions.AuthenticationError, "AuthenticationError", 4),
(litellm.exceptions.RateLimitError, "RateLimitError", 6),
(
litellm.exceptions.ContentPolicyViolationError,
"ContentPolicyViolationError",
7,
),
],
)
def test_get_num_retries_from_retry_policy(
model_list, exception_type, exception_name, num_retries
):
"""Test if the 'get_num_retries_from_retry_policy' function is working correctly"""
from litellm.router import RetryPolicy
data = {exception_name + "Retries": num_retries}
print("data", data)
router = Router(
model_list=model_list,
retry_policy=RetryPolicy(**data),
)
print("exception_type", exception_type)
calc_num_retries = router.get_num_retries_from_retry_policy(
exception=exception_type(
message="test", llm_provider="openai", model="gpt-3.5-turbo"
)
)
assert calc_num_retries == num_retries
@pytest.mark.parametrize(
"exception_type, exception_name, allowed_fails",
[
(litellm.exceptions.BadRequestError, "BadRequestError", 3),
(litellm.exceptions.AuthenticationError, "AuthenticationError", 4),
(litellm.exceptions.RateLimitError, "RateLimitError", 6),
(
litellm.exceptions.ContentPolicyViolationError,
"ContentPolicyViolationError",
7,
),
],
)
def test_get_allowed_fails_from_policy(
model_list, exception_type, exception_name, allowed_fails
):
"""Test if the 'get_allowed_fails_from_policy' function is working correctly"""
from litellm.types.router import AllowedFailsPolicy
data = {exception_name + "AllowedFails": allowed_fails}
router = Router(
model_list=model_list, allowed_fails_policy=AllowedFailsPolicy(**data)
)
calc_allowed_fails = router.get_allowed_fails_from_policy(
exception=exception_type(
message="test", llm_provider="openai", model="gpt-3.5-turbo"
)
)
assert calc_allowed_fails == allowed_fails
def test_initialize_alerting(model_list):
"""Test if the 'initialize_alerting' function is working correctly"""
from litellm.types.router import AlertingConfig
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
router = Router(
model_list=model_list, alerting_config=AlertingConfig(webhook_url="test")
)
router._initialize_alerting()
callback_added = False
for callback in litellm.callbacks:
if isinstance(callback, SlackAlerting):
callback_added = True
assert callback_added is True
def test_flush_cache(model_list):
"""Test if the 'flush_cache' function is working correctly"""
router = Router(model_list=model_list)
router.cache.set_cache("test", "test")
assert router.cache.get_cache("test") == "test"
router.flush_cache()
assert router.cache.get_cache("test") is None
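
All of the helper tests above consume the model_list fixture, only part of which is visible in the hunk that adds access_groups. For reference, an illustrative fixture consistent with the fragments shown (not the committed definition; the real fixture's gpt-4o entry is not fully visible here):

import os

import pytest

@pytest.fixture
def model_list():
    return [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "model_info": {
                "access_groups": ["group1", "group2"],
            },
        },
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
    ]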