From 63a966679491ac86f048954111f93f1ff29d7c5a Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 29 Nov 2024 15:37:16 -0800
Subject: [PATCH] feat(auth_checks.py): ensure specific model access > wildcard
 model access

if wildcard model is in access group, but specific model is not - deny access
---
 litellm/proxy/_new_secret_config.yaml         | 16 ++++++
 litellm/proxy/auth/auth_checks.py             | 28 +++++-----
 litellm/proxy/auth/user_api_key_auth.py       |  3 +
 litellm/router.py                             | 15 +++--
 .../router_utils/pattern_match_deployments.py | 18 +++++-
 litellm/types/router.py                       |  2 +-
 tests/local_testing/test_auth_checks.py       | 56 +++++++++++++++++++
 7 files changed, 118 insertions(+), 20 deletions(-)

diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 03d66351d..97ae3a54d 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -15,6 +15,22 @@ model_list:
     litellm_params:
       model: openai/gpt-4o-realtime-preview-2024-10-01
       api_key: os.environ/OPENAI_API_KEY
+  - model_name: openai/*
+    litellm_params:
+      model: openai/*
+      api_key: os.environ/OPENAI_API_KEY
+  - model_name: openai/*
+    litellm_params:
+      model: openai/*
+      api_key: os.environ/OPENAI_API_KEY
+    model_info:
+      access_groups: ["public-openai-models"]
+  - model_name: openai/gpt-4o
+    litellm_params:
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
+    model_info:
+      access_groups: ["private-openai-models"]
 
 router_settings:
   routing_strategy: usage-based-routing-v2
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 5d789436a..d0e0e805d 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -523,10 +523,6 @@ async def _cache_management_object(
     proxy_logging_obj: Optional[ProxyLogging],
 ):
     await user_api_key_cache.async_set_cache(key=key, value=value)
-    if proxy_logging_obj is not None:
-        await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache(
-            key=key, value=value
-        )
 
 
 async def _cache_team_object(
@@ -878,7 +874,10 @@ async def get_org_object(
 
 
 async def can_key_call_model(
-    model: str, llm_model_list: Optional[list], valid_token: UserAPIKeyAuth
+    model: str,
+    llm_model_list: Optional[list],
+    valid_token: UserAPIKeyAuth,
+    llm_router: Optional[litellm.Router],
 ) -> Literal[True]:
     """
     Checks if token can call a given model
@@ -898,14 +897,14 @@ async def can_key_call_model(
     )
     from collections import defaultdict
 
-    from litellm.proxy.proxy_server import llm_router
-
     access_groups = defaultdict(list)
     if llm_router:
-        access_groups = llm_router.get_model_access_groups()
+        access_groups = llm_router.get_model_access_groups(model_name=model)
 
     models_in_current_access_groups = []
-    if len(access_groups) > 0:  # check if token contains any model access groups
+    if (
+        len(access_groups) > 0 and llm_router is not None
+    ):  # check if token contains any model access groups
         for idx, m in enumerate(
             valid_token.models
         ):  # loop token models, if any of them are an access group add the access group
@@ -923,10 +922,13 @@ async def can_key_call_model(
     all_model_access: bool = False
 
     if (
-        len(filtered_models) == 0
-        or "*" in filtered_models
-        or "openai/*" in filtered_models
-    ):
+        len(filtered_models) == 0 and len(valid_token.models) == 0
+    ) or "*" in filtered_models:
+        all_model_access = True
+    elif (
+        llm_router is not None
+        and llm_router.pattern_router.route(model, filtered_models) is not None
+    ):  # wildcard access
         all_model_access = True
 
     if model is not None and model not in filtered_models and all_model_access is False:
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index c292a7dc3..e916e453a 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -259,6 +259,7 @@ async def user_api_key_auth(  # noqa: PLR0915
         jwt_handler,
         litellm_proxy_admin_name,
         llm_model_list,
+        llm_router,
         master_key,
         open_telemetry_logger,
         prisma_client,
@@ -905,6 +906,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                     model=model,
                     llm_model_list=llm_model_list,
                     valid_token=valid_token,
+                    llm_router=llm_router,
                 )
 
             if fallback_models is not None:
@@ -913,6 +915,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                         model=m,
                         llm_model_list=llm_model_list,
                         valid_token=valid_token,
+                        llm_router=llm_router,
                     )
 
             # Check 2. If user_id for this token is in budget - done in common_checks()
diff --git a/litellm/router.py b/litellm/router.py
index 3751b2403..89e7e8321 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -4712,6 +4712,9 @@ class Router:
         if hasattr(self, "model_list"):
             returned_models: List[DeploymentTypedDict] = []
 
+            if model_name is not None:
+                returned_models.extend(self._get_all_deployments(model_name=model_name))
+
             if hasattr(self, "model_group_alias"):
                 for model_alias, model_value in self.model_group_alias.items():
@@ -4743,17 +4746,21 @@ class Router:
                     returned_models += self.model_list
 
                     return returned_models
-            returned_models.extend(self._get_all_deployments(model_name=model_name))
+
             return returned_models
 
         return None
 
-    def get_model_access_groups(self):
+    def get_model_access_groups(self, model_name: Optional[str] = None):
+        """
+        If model_name is provided, only return access groups for that model.
+        """
         from collections import defaultdict
 
         access_groups = defaultdict(list)
-        if self.model_list:
-            for m in self.model_list:
+        model_list = self.get_model_list(model_name=model_name)
+        if model_list:
+            for m in model_list:
                 for group in m.get("model_info", {}).get("access_groups", []):
                     model_name = m["model_name"]
                     access_groups[group].append(model_name)
diff --git a/litellm/router_utils/pattern_match_deployments.py b/litellm/router_utils/pattern_match_deployments.py
index 3896c3a95..a369100eb 100644
--- a/litellm/router_utils/pattern_match_deployments.py
+++ b/litellm/router_utils/pattern_match_deployments.py
@@ -79,7 +79,9 @@ class PatternMatchRouter:
 
         return new_deployments
 
-    def route(self, request: Optional[str]) -> Optional[List[Dict]]:
+    def route(
+        self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
+    ) -> Optional[List[Dict]]:
         """
         Route a requested model to the corresponding llm deployments based on the regex pattern
 
@@ -89,14 +91,26 @@ def route(self, request: Optional[str]) -> Optional[List[Dict]]:
 
         Args:
             request: Optional[str]
-
+            filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
         Returns:
             Optional[List[Deployment]]: llm deployments
         """
         try:
             if request is None:
                 return None
+
+            regex_filtered_model_names = (
+                [self._pattern_to_regex(m) for m in filtered_model_names]
+                if filtered_model_names is not None
+                else []
+            )
+
             for pattern, llm_deployments in self.patterns.items():
+                if (
+                    filtered_model_names is not None
+                    and pattern not in regex_filtered_model_names
+                ):
+                    continue
                 pattern_match = re.match(pattern, request)
                 if pattern_match:
                     return self._return_pattern_matched_deployments(
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 2b7d1d83b..99d981e4d 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -355,7 +355,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
 class DeploymentTypedDict(TypedDict, total=False):
     model_name: Required[str]
     litellm_params: Required[LiteLLMParamsTypedDict]
-    model_info: Optional[dict]
+    model_info: dict
 
 
 SPECIAL_MODEL_INFO_PARAMS = [
diff --git a/tests/local_testing/test_auth_checks.py b/tests/local_testing/test_auth_checks.py
index f1683a153..0914107ad 100644
--- a/tests/local_testing/test_auth_checks.py
+++ b/tests/local_testing/test_auth_checks.py
@@ -95,3 +95,59 @@ async def test_handle_failed_db_connection():
         print("_handle_failed_db_connection_for_get_key_object got exception", exc_info)
         assert str(exc_info.value) == "Failed to connect to DB"
+
+
+@pytest.mark.parametrize(
+    "model, expect_to_work",
+    [("openai/gpt-4o-mini", True), ("openai/gpt-4o", False)],
+)
+@pytest.mark.asyncio
+async def test_can_key_call_model(model, expect_to_work):
+    """
+    If wildcard model + specific model is used, choose the specific model settings
+    """
+    from litellm.proxy.auth.auth_checks import can_key_call_model
+    from fastapi import HTTPException
+
+    llm_model_list = [
+        {
+            "model_name": "openai/*",
+            "litellm_params": {
+                "model": "openai/*",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
+                "db_model": False,
+                "access_groups": ["public-openai-models"],
+            },
+        },
+        {
+            "model_name": "openai/gpt-4o",
+            "litellm_params": {
+                "model": "openai/gpt-4o",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
+                "db_model": False,
+                "access_groups": ["private-openai-models"],
+            },
+        },
+    ]
+    router = litellm.Router(model_list=llm_model_list)
+    args = {
+        "model": model,
+        "llm_model_list": llm_model_list,
+        "valid_token": UserAPIKeyAuth(
+            models=["public-openai-models"],
+        ),
+        "llm_router": router,
+    }
+    if expect_to_work:
+        await can_key_call_model(**args)
+    else:
+        with pytest.raises(Exception) as e:
+            await can_key_call_model(**args)
+
+        print(e)
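
Reviewer note (not part of the patch): below is a minimal sketch that replays the
new test_can_key_call_model scenario end to end, illustrating the precedence this
commit enforces. A key whose models list names only the "public-openai-models"
access group can reach openai/gpt-4o-mini, which only resolves via the openai/*
wildcard deployment, but is denied openai/gpt-4o, whose dedicated deployment sits
in the separate "private-openai-models" group. The sketch assumes UserAPIKeyAuth
is importable from litellm.proxy._types (as in the proxy test suite) and uses a
placeholder api_key.

import asyncio

import litellm
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import can_key_call_model

# Same two deployments as the new test: a wildcard deployment in a public access
# group, and a specific model pinned to a private access group.
llm_model_list = [
    {
        "model_name": "openai/*",
        "litellm_params": {"model": "openai/*", "api_key": "test-api-key"},
        "model_info": {"access_groups": ["public-openai-models"]},
    },
    {
        "model_name": "openai/gpt-4o",
        "litellm_params": {"model": "openai/gpt-4o", "api_key": "test-api-key"},
        "model_info": {"access_groups": ["private-openai-models"]},
    },
]

router = litellm.Router(model_list=llm_model_list)
key = UserAPIKeyAuth(models=["public-openai-models"])  # key holds only the public group


async def main() -> None:
    # Allowed: openai/gpt-4o-mini matches only the openai/* wildcard, and that
    # deployment belongs to the key's access group.
    await can_key_call_model(
        model="openai/gpt-4o-mini",
        llm_model_list=llm_model_list,
        valid_token=key,
        llm_router=router,
    )
    print("openai/gpt-4o-mini: allowed")

    # Denied: openai/gpt-4o has its own deployment in "private-openai-models",
    # so the specific model's access group takes precedence over the wildcard.
    try:
        await can_key_call_model(
            model="openai/gpt-4o",
            llm_model_list=llm_model_list,
            valid_token=key,
            llm_router=router,
        )
    except Exception as err:
        print("openai/gpt-4o: denied -", err)


asyncio.run(main())

This mirrors the new get_model_access_groups(model_name=...) path, which scopes
the access-group lookup to the deployments serving the requested model, so a
wildcard group grant no longer implicitly exposes specifically-listed models.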