From aa8b901b1b064b7155a8b1ce129a466e0ae9e29c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 25 Feb 2025 14:50:10 -0800 Subject: [PATCH] can_team_access_model --- litellm/proxy/auth/auth_checks.py | 92 +++++++++++++------------------ 1 file changed, 39 insertions(+), 53 deletions(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 0590bcb50a..c922599f86 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -38,6 +38,7 @@ from litellm.proxy._types import ( ProxyErrorTypes, ProxyException, RoleBasedPermissions, + SpecialModelNames, UserAPIKeyAuth, ) from litellm.proxy.auth.route_checks import RouteChecks @@ -97,12 +98,23 @@ async def common_checks( ) # 2. If team can call model - _team_model_access_check( - team_object=team_object, - model=_model, - llm_router=llm_router, - team_model_aliases=valid_token.team_model_aliases if valid_token else None, - ) + if ( + team_object is not None + and _model is not None + and can_team_access_model( + model=_model, + team_object=team_object, + llm_router=llm_router, + team_model_aliases=valid_token.team_model_aliases if valid_token else None, + ) + is False + ): + raise ProxyException( + message=f"Team not allowed to access model. Team={team_object.team_id}, Model={_model}. Allowed team models = {team_object.models}", + type=ProxyErrorTypes.team_model_access_denied, + param="model", + code=status.HTTP_401_UNAUTHORIZED, + ) ## 2.1 If user can call model (if personal key) if team_object is None and user_object is not None: @@ -1017,6 +1029,9 @@ async def _can_object_call_model( if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models: all_model_access = True + if SpecialModelNames.all_proxy_models in filtered_models: + all_model_access = True + if model is not None and model not in filtered_models and all_model_access is False: raise ProxyException( message=f"API Key not allowed to access model. This token can only access models={models}. Tried to access {model}", @@ -1074,6 +1089,24 @@ async def can_key_call_model( ) +async def can_team_access_model( + model: str, + team_object: Optional[LiteLLM_TeamTable], + llm_router: Optional[Router], + team_model_aliases: Optional[Dict[str, str]] = None, +) -> Literal[True]: + """ + Returns True if the team can access a specific model. + + """ + return await _can_object_call_model( + model=model, + llm_router=llm_router, + models=team_object.models if team_object else [], + team_model_aliases=team_model_aliases, + ) + + async def can_user_call_model( model: str, llm_router: Optional[Router], @@ -1239,53 +1272,6 @@ async def _team_max_budget_check( ) -def _team_model_access_check( - model: Optional[str], - team_object: Optional[LiteLLM_TeamTable], - llm_router: Optional[Router], - team_model_aliases: Optional[Dict[str, str]] = None, -): - """ - Access check for team models - Raises: - Exception if the team is not allowed to call the`model` - """ - if ( - model is not None - and team_object is not None - and team_object.models is not None - and len(team_object.models) > 0 - and model not in team_object.models - ): - # this means the team has access to all models on the proxy - if "all-proxy-models" in team_object.models or "*" in team_object.models: - # this means the team has access to all models on the proxy - pass - # check if the team model is an access_group - elif ( - model_in_access_group( - model=model, team_models=team_object.models, llm_router=llm_router - ) - is True - ): - pass - elif model and "*" in model: - pass - elif _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases): - pass - elif _model_matches_any_wildcard_pattern_in_list( - model=model, allowed_model_list=team_object.models - ): - pass - else: - raise ProxyException( - message=f"Team not allowed to access model. Team={team_object.team_id}, Model={model}. Allowed team models = {team_object.models}", - type=ProxyErrorTypes.team_model_access_denied, - param="model", - code=status.HTTP_401_UNAUTHORIZED, - ) - - def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool: """ Check if a model matches an allowed pattern.