add unit testing for new model regex pattern matching

This commit is contained in:
Ishaan Jaff 2024-11-25 15:35:53 -08:00
parent 79d51e0bdd
commit f76d0e41ef

View file

@ -15,6 +15,10 @@ from starlette.datastructures import URL
import litellm import litellm
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.auth.auth_checks import (
_model_is_within_list_of_allowed_models,
_model_matches_patterns,
)
class Request: class Request:
@ -387,3 +391,59 @@ def test_is_api_route_allowed(route, user_role, expected_result):
pass pass
else: else:
raise e raise e
@pytest.mark.parametrize(
"model, allowed_models, expected_result",
[
# Empty allowed models list
("gpt-4", [], True),
# Universal access patterns
("gpt-4", ["*"], True),
("gpt-4", ["all-proxy-models"], True),
# Exact matches
("gpt-4", ["gpt-4"], True),
("gpt-4", ["gpt-3.5", "claude-2"], False),
# Pattern matching
("azure/gpt-4", ["azure/*"], True),
("custom_engine/model-123", ["custom_engine/*"], True),
("custom_engine/model-123", ["custom_engine/*", "azure/*"], True),
("custom-engine/model-123", ["custom_engine/*", "azure/*"], False),
("gpt-4", ["gpt-*"], True),
("gpt-4", ["claude-*"], False),
# Mixed scenarios
("gpt-4", ["claude-2", "gpt-*", "palm-2"], True),
("anthropic/claude-instant-1", ["anthropic/*", "gpt-4"], True),
],
)
def test_model_is_within_list_of_allowed_models(model, allowed_models, expected_result):
result = _model_is_within_list_of_allowed_models(
model=model, allowed_models=allowed_models
)
assert result == expected_result
@pytest.mark.parametrize(
"model, allowed_models, expected_result",
[
# Basic pattern matching
("gpt-4", ["gpt-*"], True),
("azure/gpt-4", ["azure/*"], True),
("custom_engine/model-123", ["custom_engine/*"], True),
("custom_engine/model-123", ["custom_engine/*", "azure/*"], True),
("custom-engine/model-123", ["custom_engine/*", "azure/*"], False),
# Multiple patterns
("gpt-4", ["claude-*", "gpt-*"], True),
("anthropic/claude-instant-1", ["anthropic/*", "gpt-*"], True),
# No matches
("gpt-4", ["claude-*"], False),
("azure/gpt-4", ["anthropic/*"], False),
# Edge cases
("gpt-4", ["gpt-4"], False), # No wildcard, should return False
("model", ["*-suffix"], False),
("prefix-model", ["prefix-*"], True),
],
)
def test_model_matches_patterns(model, allowed_models, expected_result):
result = _model_matches_patterns(model=model, allowed_models=allowed_models)
assert result == expected_result