From aa1621757c968ba94c49d360d89c0a9190c0508d Mon Sep 17 00:00:00 2001
From: Krrish Dholakia <krrishdholakia@gmail.com>
Date: Fri, 29 Nov 2024 16:02:05 -0800
Subject: [PATCH] fix(auth_checks.py): handle auth checks for team based model
 access groups

handles scenario where model access group used for wildcard models
---
 litellm/proxy/auth/auth_checks.py       | 34 ++++++++----------
 litellm/proxy/auth/user_api_key_auth.py |  2 ++
 tests/local_testing/test_auth_checks.py | 48 +++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 19 deletions(-)

diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index d0e0e805d..315d9dd36 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -60,6 +60,7 @@ def common_checks(  # noqa: PLR0915
     global_proxy_spend: Optional[float],
     general_settings: dict,
     route: str,
+    llm_router: Optional[litellm.Router],
 ) -> bool:
     """
     Common checks across jwt + key-based auth.
@@ -97,7 +98,12 @@
             # this means the team has access to all models on the proxy
             pass
         # check if the team model is an access_group
-        elif model_in_access_group(_model, team_object.models) is True:
+        elif (
+            model_in_access_group(
+                model=_model, team_models=team_object.models, llm_router=llm_router
+            )
+            is True
+        ):
             pass
         elif _model and "*" in _model:
             pass
@@ -373,36 +379,33 @@ async def get_end_user_object(
     return None
 
 
-def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
+def model_in_access_group(
+    model: str, team_models: Optional[List[str]], llm_router: Optional[litellm.Router]
+) -> bool:
     from collections import defaultdict
 
-    from litellm.proxy.proxy_server import llm_router
-
     if team_models is None:
         return True
     if model in team_models:
         return True
 
-    access_groups = defaultdict(list)
+    access_groups: dict[str, list[str]] = defaultdict(list)
     if llm_router:
-        access_groups = llm_router.get_model_access_groups()
+        access_groups = llm_router.get_model_access_groups(model_name=model)
 
-    models_in_current_access_groups = []
     if len(access_groups) > 0:  # check if token contains any model access groups
         for idx, m in enumerate(
             team_models
         ):  # loop token models, if any of them are an access group add the access group
             if m in access_groups:
-                # if it is an access group we need to remove it from valid_token.models
-                models_in_group = access_groups[m]
-                models_in_current_access_groups.extend(models_in_group)
+                return True
 
     # Filter out models that are access_groups
     filtered_models = [m for m in team_models if m not in access_groups]
-    filtered_models += models_in_current_access_groups
 
     if model in filtered_models:
         return True
+
     return False
 
 
@@ -909,9 +912,7 @@
             valid_token.models
         ):  # loop token models, if any of them are an access group add the access group
             if m in access_groups:
-                # if it is an access group we need to remove it from valid_token.models
-                models_in_group = access_groups[m]
-                models_in_current_access_groups.extend(models_in_group)
+                return True
 
     # Filter out models that are access_groups
     filtered_models = [m for m in valid_token.models if m not in access_groups]
@@ -925,11 +926,6 @@
         len(filtered_models) == 0 and len(valid_token.models) == 0
     ) or "*" in filtered_models:
         all_model_access = True
-    elif (
-        llm_router is not None
-        and llm_router.pattern_router.route(model, filtered_models) is not None
-    ):  # wildcard access
-        all_model_access = True
 
     if model is not None and model not in filtered_models and all_model_access is False:
         raise ValueError(
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index e916e453a..d0d3b2e9f 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -543,6 +543,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                 general_settings=general_settings,
                 global_proxy_spend=global_proxy_spend,
                 route=route,
+                llm_router=llm_router,
             )
 
         # return UserAPIKeyAuth object
@@ -1176,6 +1177,7 @@
             general_settings=general_settings,
             global_proxy_spend=global_proxy_spend,
             route=route,
+            llm_router=llm_router,
         )
         # Token passed all checks
         if valid_token is None:
diff --git a/tests/local_testing/test_auth_checks.py b/tests/local_testing/test_auth_checks.py
index 0914107ad..67b5cf11d 100644
--- a/tests/local_testing/test_auth_checks.py
+++ b/tests/local_testing/test_auth_checks.py
@@ -151,3 +151,51 @@ async def test_can_key_call_model(model, expect_to_work):
             await can_key_call_model(**args)
         print(e)
 
+
+
+@pytest.mark.parametrize(
+    "model, expect_to_work",
+    [("openai/gpt-4o", False), ("openai/gpt-4o-mini", True)],
+)
+@pytest.mark.asyncio
+async def test_can_team_call_model(model, expect_to_work):
+    from litellm.proxy.auth.auth_checks import model_in_access_group
+    from fastapi import HTTPException
+
+    llm_model_list = [
+        {
+            "model_name": "openai/*",
+            "litellm_params": {
+                "model": "openai/*",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
+                "db_model": False,
+                "access_groups": ["public-openai-models"],
+            },
+        },
+        {
+            "model_name": "openai/gpt-4o",
+            "litellm_params": {
+                "model": "openai/gpt-4o",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
+                "db_model": False,
+                "access_groups": ["private-openai-models"],
+            },
+        },
+    ]
+    router = litellm.Router(model_list=llm_model_list)
+
+    args = {
+        "model": model,
+        "team_models": ["public-openai-models"],
+        "llm_router": router,
+    }
+    if expect_to_work:
+        assert model_in_access_group(**args)
+    else:
+        assert not model_in_access_group(**args)