diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index baca826804..c150534110 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -874,6 +874,11 @@ async def can_key_call_model(
 
     verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
 
+    if _model_matches_any_wildcard_pattern_in_list(
+        model=model, allowed_model_list=filtered_models
+    ):
+        return True
+
     all_model_access: bool = False
 
     if (
@@ -1076,9 +1081,8 @@ def _team_model_access_check(
         pass
     elif model and "*" in model:
         pass
-    elif any(
-        is_model_allowed_by_pattern(model=model, allowed_model_pattern=team_model)
-        for team_model in team_object.models
+    elif _model_matches_any_wildcard_pattern_in_list(
+        model=model, allowed_model_list=team_object.models
     ):
         pass
     else:
@@ -1104,3 +1108,23 @@ def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
         return bool(re.match(pattern, model))
 
     return False
+
+
+def _model_matches_any_wildcard_pattern_in_list(
+    model: str, allowed_model_list: list
+) -> bool:
+    """
+    Returns True if a model matches any wildcard pattern in a list.
+
+    eg.
+    - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns True
+    - model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
+    - model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
+    """
+    return any(
+        "*" in allowed_model_pattern
+        and is_model_allowed_by_pattern(
+            model=model, allowed_model_pattern=allowed_model_pattern
+        )
+        for allowed_model_pattern in allowed_model_list
+    )
diff --git a/tests/proxy_unit_tests/test_auth_checks.py b/tests/proxy_unit_tests/test_auth_checks.py
index 68ab5cae6e..85b5b216a5 100644
--- a/tests/proxy_unit_tests/test_auth_checks.py
+++ b/tests/proxy_unit_tests/test_auth_checks.py
@@ -110,7 +110,10 @@ async def test_handle_failed_db_connection():
 
 @pytest.mark.parametrize(
     "model, expect_to_work",
-    [("openai/gpt-4o-mini", True), ("openai/gpt-4o", False)],
+    [
+        ("openai/gpt-4o-mini", True),
+        ("openai/gpt-4o", False),
+    ],
 )
 @pytest.mark.asyncio
 async def test_can_key_call_model(model, expect_to_work):
@@ -212,6 +215,82 @@ async def test_can_team_call_model(model, expect_to_work):
         assert not model_in_access_group(**args)
 
 
+@pytest.mark.parametrize(
+    "key_models, model, expect_to_work",
+    [
+        (["openai/*"], "openai/gpt-4o", True),
+        (["openai/*"], "openai/gpt-4o-mini", True),
+        (["openai/*"], "openaiz/gpt-4o-mini", False),
+        (["bedrock/*"], "bedrock/anthropic.claude-3-5-sonnet-20240620", True),
+        (["bedrock/*"], "bedrockz/anthropic.claude-3-5-sonnet-20240620", False),
+        (["bedrock/us.*"], "bedrock/us.amazon.nova-micro-v1:0", True),
+    ],
+)
+@pytest.mark.asyncio
+async def test_can_key_call_model_wildcard_access(key_models, model, expect_to_work):
+    from litellm.proxy.auth.auth_checks import can_key_call_model
+    from fastapi import HTTPException
+
+    llm_model_list = [
+        {
+            "model_name": "openai/*",
+            "litellm_params": {
+                "model": "openai/*",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
+                "db_model": False,
+            },
+        },
+        {
+            "model_name": "bedrock/*",
+            "litellm_params": {
+                "model": "bedrock/*",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
+                "db_model": False,
+            },
+        },
+        {
+            "model_name": "openai/gpt-4o",
+            "litellm_params": {
+                "model": "openai/gpt-4o",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
+                "db_model": False,
+            },
+        },
+    ]
+    router = litellm.Router(model_list=llm_model_list)
+
+    user_api_key_object = UserAPIKeyAuth(
+        models=key_models,
+    )
+
+    if expect_to_work:
+        await can_key_call_model(
+            model=model,
+            llm_model_list=llm_model_list,
+            valid_token=user_api_key_object,
+            llm_router=router,
+        )
+    else:
+        with pytest.raises(Exception) as e:
+            await can_key_call_model(
+                model=model,
+                llm_model_list=llm_model_list,
+                valid_token=user_api_key_object,
+                llm_router=router,
+            )
+
+        print(e)
+
+
 @pytest.mark.asyncio
 async def test_is_valid_fallback_model():
     from litellm.proxy.auth.auth_checks import is_valid_fallback_model