mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(QA / testing) - Add unit testing for key model access checks (#7999)
* fix _model_matches_any_wildcard_pattern_in_list * fix docstring
This commit is contained in:
parent
c2fa213ae2
commit
d9dcfccdf6
2 changed files with 107 additions and 4 deletions
|
@ -874,6 +874,11 @@ async def can_key_call_model(
|
|||
|
||||
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
|
||||
|
||||
if _model_matches_any_wildcard_pattern_in_list(
|
||||
model=model, allowed_model_list=filtered_models
|
||||
):
|
||||
return True
|
||||
|
||||
all_model_access: bool = False
|
||||
|
||||
if (
|
||||
|
@ -1076,9 +1081,8 @@ def _team_model_access_check(
|
|||
pass
|
||||
elif model and "*" in model:
|
||||
pass
|
||||
elif any(
|
||||
is_model_allowed_by_pattern(model=model, allowed_model_pattern=team_model)
|
||||
for team_model in team_object.models
|
||||
elif _model_matches_any_wildcard_pattern_in_list(
|
||||
model=model, allowed_model_list=team_object.models
|
||||
):
|
||||
pass
|
||||
else:
|
||||
|
@ -1104,3 +1108,23 @@ def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
|
|||
return bool(re.match(pattern, model))
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _model_matches_any_wildcard_pattern_in_list(
|
||||
model: str, allowed_model_list: list
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if a model matches any wildcard pattern in a list.
|
||||
|
||||
eg.
|
||||
- model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns True
|
||||
- model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
|
||||
- model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
|
||||
"""
|
||||
return any(
|
||||
"*" in allowed_model_pattern
|
||||
and is_model_allowed_by_pattern(
|
||||
model=model, allowed_model_pattern=allowed_model_pattern
|
||||
)
|
||||
for allowed_model_pattern in allowed_model_list
|
||||
)
|
||||
|
|
|
@ -110,7 +110,10 @@ async def test_handle_failed_db_connection():
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
"model, expect_to_work",
|
||||
[("openai/gpt-4o-mini", True), ("openai/gpt-4o", False)],
|
||||
[
|
||||
("openai/gpt-4o-mini", True),
|
||||
("openai/gpt-4o", False),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_key_call_model(model, expect_to_work):
|
||||
|
@ -212,6 +215,82 @@ async def test_can_team_call_model(model, expect_to_work):
|
|||
assert not model_in_access_group(**args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"key_models, model, expect_to_work",
|
||||
[
|
||||
(["openai/*"], "openai/gpt-4o", True),
|
||||
(["openai/*"], "openai/gpt-4o-mini", True),
|
||||
(["openai/*"], "openaiz/gpt-4o-mini", False),
|
||||
(["bedrock/*"], "bedrock/anthropic.claude-3-5-sonnet-20240620", True),
|
||||
(["bedrock/*"], "bedrockz/anthropic.claude-3-5-sonnet-20240620", False),
|
||||
(["bedrock/us.*"], "bedrock/us.amazon.nova-micro-v1:0", True),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_key_call_model_wildcard_access(key_models, model, expect_to_work):
|
||||
from litellm.proxy.auth.auth_checks import can_key_call_model
|
||||
from fastapi import HTTPException
|
||||
|
||||
llm_model_list = [
|
||||
{
|
||||
"model_name": "openai/*",
|
||||
"litellm_params": {
|
||||
"model": "openai/*",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
|
||||
"db_model": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "bedrock/*",
|
||||
"litellm_params": {
|
||||
"model": "bedrock/*",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
|
||||
"db_model": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "openai/gpt-4o",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
|
||||
"db_model": False,
|
||||
},
|
||||
},
|
||||
]
|
||||
router = litellm.Router(model_list=llm_model_list)
|
||||
|
||||
user_api_key_object = UserAPIKeyAuth(
|
||||
models=key_models,
|
||||
)
|
||||
|
||||
if expect_to_work:
|
||||
await can_key_call_model(
|
||||
model=model,
|
||||
llm_model_list=llm_model_list,
|
||||
valid_token=user_api_key_object,
|
||||
llm_router=router,
|
||||
)
|
||||
else:
|
||||
with pytest.raises(Exception) as e:
|
||||
await can_key_call_model(
|
||||
model=model,
|
||||
llm_model_list=llm_model_list,
|
||||
valid_token=user_api_key_object,
|
||||
llm_router=router,
|
||||
)
|
||||
|
||||
print(e)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_valid_fallback_model():
|
||||
from litellm.proxy.auth.auth_checks import is_valid_fallback_model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue