From e9cdbff75a3721748f7223a55ddc498049fec5ed Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 13:58:08 -0800 Subject: [PATCH] add check for _model_is_within_list_of_allowed_models --- litellm/proxy/auth/auth_checks.py | 46 +++++++++++++++++++------------ 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index b1c27a409..52f7f22fe 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -92,17 +92,15 @@ def common_checks( # noqa: PLR0915 ): # this means the team has access to all models on the proxy if ( - "all-proxy-models" in team_object.models - or "*" in team_object.models - or "openai/*" in team_object.models + _model_is_within_list_of_allowed_models( + model=_model, allowed_models=team_object.models + ) + is True ): - # this means the team has access to all models on the proxy pass # check if the team model is an access_group elif model_in_access_group(_model, team_object.models) is True: pass - elif _model and "*" in _model: - pass else: raise Exception( f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}" @@ -868,18 +866,16 @@ async def can_key_call_model( filtered_models += models_in_current_access_groups verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}") - # Check for universal access patterns - if len(filtered_models) == 0: - return True - if "*" in filtered_models: - return True - if model_matches_patterns(model=model, allowed_models=filtered_models) is True: - return True - - if model is not None and model not in filtered_models: - raise ValueError( - f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}" + if ( + _model_is_within_list_of_allowed_models( + model=model, allowed_models=filtered_models ) + is False + ): + raise ValueError( + f"API Key not allowed to access model. List of allowed models={filtered_models}. Tried to access {model}" + ) + valid_token.models = filtered_models verbose_proxy_logger.debug( f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}" @@ -887,6 +883,22 @@ async def can_key_call_model( return True +def _model_is_within_list_of_allowed_models( + model: str, allowed_models: List[str] +) -> bool: + # Check for universal access patterns + if len(allowed_models) == 0: + return True + if "*" in allowed_models: + return True + if "all-proxy-models" in allowed_models: + return True + if model_matches_patterns(model=model, allowed_models=allowed_models) is True: + return True + + return False + + def model_matches_patterns(model: str, allowed_models: List[str]) -> bool: """ Helper function to check if a model matches any of the allowed model patterns.