fix(auth_checks.py): handle auth checks for team-based model access groups

handles the scenario where a model access group is used for wildcard models
Krrish Dholakia 2024-11-29 16:02:05 -08:00
parent 63a9666794
commit aa1621757c
3 changed files with 65 additions and 19 deletions
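The scenario being fixed: a proxy `model_list` entry can attach `access_groups` to a wildcard deployment such as `openai/*`, and a team can be granted the group name instead of concrete model names. Previously the access-group lookup was not scoped to the requested model, so groups backed by wildcard routes were missed. A minimal sketch of the setup, adapted from the test added in this commit (the API key is a placeholder):

import litellm

# Wildcard deployment that advertises an access group (taken from the new test below).
router = litellm.Router(
    model_list=[
        {
            "model_name": "openai/*",
            "litellm_params": {"model": "openai/*", "api_key": "test-api-key"},
            "model_info": {"access_groups": ["public-openai-models"]},
        }
    ]
)

# After this commit, the lookup is scoped to the requested model, so a model
# that only matches the wildcard route still reports its access groups:
print(router.get_model_access_groups(model_name="openai/gpt-4o-mini"))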


@@ -60,6 +60,7 @@ def common_checks( # noqa: PLR0915
     global_proxy_spend: Optional[float],
     general_settings: dict,
     route: str,
+    llm_router: Optional[litellm.Router],
 ) -> bool:
     """
     Common checks across jwt + key-based auth.
@@ -97,7 +98,12 @@ def common_checks( # noqa: PLR0915
             # this means the team has access to all models on the proxy
             pass
         # check if the team model is an access_group
-        elif model_in_access_group(_model, team_object.models) is True:
+        elif (
+            model_in_access_group(
+                model=_model, team_models=team_object.models, llm_router=llm_router
+            )
+            is True
+        ):
             pass
         elif _model and "*" in _model:
             pass
@@ -373,36 +379,33 @@ async def get_end_user_object(
     return None


-def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
+def model_in_access_group(
+    model: str, team_models: Optional[List[str]], llm_router: Optional[litellm.Router]
+) -> bool:
     from collections import defaultdict

-    from litellm.proxy.proxy_server import llm_router
-
     if team_models is None:
         return True
     if model in team_models:
         return True

-    access_groups = defaultdict(list)
+    access_groups: dict[str, list[str]] = defaultdict(list)
     if llm_router:
-        access_groups = llm_router.get_model_access_groups()
+        access_groups = llm_router.get_model_access_groups(model_name=model)

-    models_in_current_access_groups = []
     if len(access_groups) > 0:  # check if token contains any model access groups
         for idx, m in enumerate(
             team_models
         ):  # loop token models, if any of them are an access group add the access group
             if m in access_groups:
-                # if it is an access group we need to remove it from valid_token.models
-                models_in_group = access_groups[m]
-                models_in_current_access_groups.extend(models_in_group)
+                return True

     # Filter out models that are access_groups
     filtered_models = [m for m in team_models if m not in access_groups]
-    filtered_models += models_in_current_access_groups
+
     if model in filtered_models:
         return True
     return False
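Net effect of the rewrite above: instead of expanding each access group into a flat list of model names and checking membership, model_in_access_group now asks the router only for groups relevant to the requested model, via get_model_access_groups(model_name=model), and returns True as soon as the team holds one of them. Assuming the model-scoped lookup also matches wildcard routes (which the new test exercises), a usage sketch reusing the router from the example above:

from litellm.proxy.auth.auth_checks import model_in_access_group

# The team holds only the access-group name; "openai/gpt-4o-mini" is covered
# by the wildcard deployment that advertises that group.
assert model_in_access_group(
    model="openai/gpt-4o-mini",
    team_models=["public-openai-models"],
    llm_router=router,
)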
@@ -909,9 +912,7 @@ async def can_key_call_model(
             valid_token.models
         ):  # loop token models, if any of them are an access group add the access group
             if m in access_groups:
-                # if it is an access group we need to remove it from valid_token.models
-                models_in_group = access_groups[m]
-                models_in_current_access_groups.extend(models_in_group)
+                return True

     # Filter out models that are access_groups
     filtered_models = [m for m in valid_token.models if m not in access_groups]
@@ -925,11 +926,6 @@ async def can_key_call_model(
         len(filtered_models) == 0 and len(valid_token.models) == 0
     ) or "*" in filtered_models:
         all_model_access = True
-    elif (
-        llm_router is not None
-        and llm_router.pattern_router.route(model, filtered_models) is not None
-    ):  # wildcard access
-        all_model_access = True

     if model is not None and model not in filtered_models and all_model_access is False:
         raise ValueError(
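The removed elif was the old wildcard fallback: it granted access whenever the pattern router matched the requested model against a wildcard entry in the key's model list. With get_model_access_groups(model_name=...) now performing model-scoped matching, a wildcard deployment's access groups already cover models like openai/gpt-4o-mini, which appears to make this separate branch redundant for the access-group path this commit targets.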


@@ -543,6 +543,7 @@ async def user_api_key_auth( # noqa: PLR0915
             general_settings=general_settings,
             global_proxy_spend=global_proxy_spend,
             route=route,
+            llm_router=llm_router,
         )

        # return UserAPIKeyAuth object
@@ -1176,6 +1177,7 @@ async def user_api_key_auth( # noqa: PLR0915
                 general_settings=general_settings,
                 global_proxy_spend=global_proxy_spend,
                 route=route,
+                llm_router=llm_router,
             )
         # Token passed all checks
         if valid_token is None:


@@ -151,3 +151,51 @@ async def test_can_key_call_model(model, expect_to_work):
             await can_key_call_model(**args)
         print(e)
+
+
+@pytest.mark.parametrize(
+    "model, expect_to_work",
+    [("openai/gpt-4o", False), ("openai/gpt-4o-mini", True)],
+)
+@pytest.mark.asyncio
+async def test_can_team_call_model(model, expect_to_work):
+    from litellm.proxy.auth.auth_checks import model_in_access_group
+    from fastapi import HTTPException
+
+    llm_model_list = [
+        {
+            "model_name": "openai/*",
+            "litellm_params": {
+                "model": "openai/*",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
+                "db_model": False,
+                "access_groups": ["public-openai-models"],
+            },
+        },
+        {
+            "model_name": "openai/gpt-4o",
+            "litellm_params": {
+                "model": "openai/gpt-4o",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
+                "db_model": False,
+                "access_groups": ["private-openai-models"],
+            },
+        },
+    ]
+    router = litellm.Router(model_list=llm_model_list)
+    args = {
+        "model": model,
+        "team_models": ["public-openai-models"],
+        "llm_router": router,
+    }
+
+    if expect_to_work:
+        assert model_in_access_group(**args)
+    else:
+        assert not model_in_access_group(**args)
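The two parametrized cases capture both sides of the fix: openai/gpt-4o-mini has no explicit deployment but is covered by the wildcard openai/* route in public-openai-models, the one group the team holds, so it must be allowed; openai/gpt-4o resolves to the deployment in private-openai-models, which the team does not hold, so it must be denied.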