(QA / testing) - Add unit testing for key model access checks (#7999)

* fix _model_matches_any_wildcard_pattern_in_list

* fix docstring
This commit is contained in:
Ishaan Jaff 2025-01-25 10:01:35 -08:00 committed by GitHub
parent c2fa213ae2
commit d9dcfccdf6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 107 additions and 4 deletions

View file

@ -874,6 +874,11 @@ async def can_key_call_model(
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}") verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
if _model_matches_any_wildcard_pattern_in_list(
model=model, allowed_model_list=filtered_models
):
return True
all_model_access: bool = False all_model_access: bool = False
if ( if (
@ -1076,9 +1081,8 @@ def _team_model_access_check(
pass pass
elif model and "*" in model: elif model and "*" in model:
pass pass
elif any( elif _model_matches_any_wildcard_pattern_in_list(
is_model_allowed_by_pattern(model=model, allowed_model_pattern=team_model) model=model, allowed_model_list=team_object.models
for team_model in team_object.models
): ):
pass pass
else: else:
@ -1104,3 +1108,23 @@ def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
return bool(re.match(pattern, model)) return bool(re.match(pattern, model))
return False return False
def _model_matches_any_wildcard_pattern_in_list(
model: str, allowed_model_list: list
) -> bool:
"""
Returns True if a model matches any wildcard pattern in a list.
eg.
- model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns True
- model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
- model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
"""
return any(
"*" in allowed_model_pattern
and is_model_allowed_by_pattern(
model=model, allowed_model_pattern=allowed_model_pattern
)
for allowed_model_pattern in allowed_model_list
)

View file

@ -110,7 +110,10 @@ async def test_handle_failed_db_connection():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model, expect_to_work", "model, expect_to_work",
[("openai/gpt-4o-mini", True), ("openai/gpt-4o", False)], [
("openai/gpt-4o-mini", True),
("openai/gpt-4o", False),
],
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_can_key_call_model(model, expect_to_work): async def test_can_key_call_model(model, expect_to_work):
@ -212,6 +215,82 @@ async def test_can_team_call_model(model, expect_to_work):
assert not model_in_access_group(**args) assert not model_in_access_group(**args)
@pytest.mark.parametrize(
"key_models, model, expect_to_work",
[
(["openai/*"], "openai/gpt-4o", True),
(["openai/*"], "openai/gpt-4o-mini", True),
(["openai/*"], "openaiz/gpt-4o-mini", False),
(["bedrock/*"], "bedrock/anthropic.claude-3-5-sonnet-20240620", True),
(["bedrock/*"], "bedrockz/anthropic.claude-3-5-sonnet-20240620", False),
(["bedrock/us.*"], "bedrock/us.amazon.nova-micro-v1:0", True),
],
)
@pytest.mark.asyncio
async def test_can_key_call_model_wildcard_access(key_models, model, expect_to_work):
from litellm.proxy.auth.auth_checks import can_key_call_model
from fastapi import HTTPException
llm_model_list = [
{
"model_name": "openai/*",
"litellm_params": {
"model": "openai/*",
"api_key": "test-api-key",
},
"model_info": {
"id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
"db_model": False,
},
},
{
"model_name": "bedrock/*",
"litellm_params": {
"model": "bedrock/*",
"api_key": "test-api-key",
},
"model_info": {
"id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
"db_model": False,
},
},
{
"model_name": "openai/gpt-4o",
"litellm_params": {
"model": "openai/gpt-4o",
"api_key": "test-api-key",
},
"model_info": {
"id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
"db_model": False,
},
},
]
router = litellm.Router(model_list=llm_model_list)
user_api_key_object = UserAPIKeyAuth(
models=key_models,
)
if expect_to_work:
await can_key_call_model(
model=model,
llm_model_list=llm_model_list,
valid_token=user_api_key_object,
llm_router=router,
)
else:
with pytest.raises(Exception) as e:
await can_key_call_model(
model=model,
llm_model_list=llm_model_list,
valid_token=user_api_key_object,
llm_router=router,
)
print(e)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_valid_fallback_model(): async def test_is_valid_fallback_model():
from litellm.proxy.auth.auth_checks import is_valid_fallback_model from litellm.proxy.auth.auth_checks import is_valid_fallback_model