From f76d0e41ef513a755352b05cbe2777e48ce6f01e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 15:35:53 -0800 Subject: [PATCH] add unit testing for new model regex pattern matching --- .../test_user_api_key_auth.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py index 31daa358a..913ff34e1 100644 --- a/tests/proxy_unit_tests/test_user_api_key_auth.py +++ b/tests/proxy_unit_tests/test_user_api_key_auth.py @@ -15,6 +15,10 @@ from starlette.datastructures import URL import litellm from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.auth.auth_checks import ( + _model_is_within_list_of_allowed_models, + _model_matches_patterns, +) class Request: @@ -387,3 +391,59 @@ def test_is_api_route_allowed(route, user_role, expected_result): pass else: raise e + + +@pytest.mark.parametrize( + "model, allowed_models, expected_result", + [ + # Empty allowed models list + ("gpt-4", [], True), + # Universal access patterns + ("gpt-4", ["*"], True), + ("gpt-4", ["all-proxy-models"], True), + # Exact matches + ("gpt-4", ["gpt-4"], True), + ("gpt-4", ["gpt-3.5", "claude-2"], False), + # Pattern matching + ("azure/gpt-4", ["azure/*"], True), + ("custom_engine/model-123", ["custom_engine/*"], True), + ("custom_engine/model-123", ["custom_engine/*", "azure/*"], True), + ("custom-engine/model-123", ["custom_engine/*", "azure/*"], False), + ("gpt-4", ["gpt-*"], True), + ("gpt-4", ["claude-*"], False), + # Mixed scenarios + ("gpt-4", ["claude-2", "gpt-*", "palm-2"], True), + ("anthropic/claude-instant-1", ["anthropic/*", "gpt-4"], True), + ], +) +def test_model_is_within_list_of_allowed_models(model, allowed_models, expected_result): + result = _model_is_within_list_of_allowed_models( + model=model, allowed_models=allowed_models + ) + assert result == expected_result + + +@pytest.mark.parametrize( + "model, allowed_models, expected_result", + [ + # Basic pattern matching + ("gpt-4", ["gpt-*"], True), + ("azure/gpt-4", ["azure/*"], True), + ("custom_engine/model-123", ["custom_engine/*"], True), + ("custom_engine/model-123", ["custom_engine/*", "azure/*"], True), + ("custom-engine/model-123", ["custom_engine/*", "azure/*"], False), + # Multiple patterns + ("gpt-4", ["claude-*", "gpt-*"], True), + ("anthropic/claude-instant-1", ["anthropic/*", "gpt-*"], True), + # No matches + ("gpt-4", ["claude-*"], False), + ("azure/gpt-4", ["anthropic/*"], False), + # Edge cases + ("gpt-4", ["gpt-4"], False), # No wildcard, should return False + ("model", ["*-suffix"], False), + ("prefix-model", ["prefix-*"], True), + ], +) +def test_model_matches_patterns(model, allowed_models, expected_result): + result = _model_matches_patterns(model=model, allowed_models=allowed_models) + assert result == expected_result