can_team_access_model

This commit is contained in:
Ishaan Jaff 2025-02-25 14:50:10 -08:00
parent 835df849b0
commit eeee61db65

View file

@ -38,6 +38,7 @@ from litellm.proxy._types import (
ProxyErrorTypes, ProxyErrorTypes,
ProxyException, ProxyException,
RoleBasedPermissions, RoleBasedPermissions,
SpecialModelNames,
UserAPIKeyAuth, UserAPIKeyAuth,
) )
from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.auth.route_checks import RouteChecks
@ -97,12 +98,23 @@ async def common_checks(
) )
# 2. If team can call model # 2. If team can call model
_team_model_access_check( if (
team_object=team_object, team_object is not None
model=_model, and _model is not None
llm_router=llm_router, and can_team_access_model(
team_model_aliases=valid_token.team_model_aliases if valid_token else None, model=_model,
) team_object=team_object,
llm_router=llm_router,
team_model_aliases=valid_token.team_model_aliases if valid_token else None,
)
is False
):
raise ProxyException(
message=f"Team not allowed to access model. Team={team_object.team_id}, Model={_model}. Allowed team models = {team_object.models}",
type=ProxyErrorTypes.team_model_access_denied,
param="model",
code=status.HTTP_401_UNAUTHORIZED,
)
## 2.1 If user can call model (if personal key) ## 2.1 If user can call model (if personal key)
if team_object is None and user_object is not None: if team_object is None and user_object is not None:
@ -1017,6 +1029,9 @@ async def _can_object_call_model(
if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models: if (len(filtered_models) == 0 and len(models) == 0) or "*" in filtered_models:
all_model_access = True all_model_access = True
if SpecialModelNames.all_proxy_models in filtered_models:
all_model_access = True
if model is not None and model not in filtered_models and all_model_access is False: if model is not None and model not in filtered_models and all_model_access is False:
raise ProxyException( raise ProxyException(
message=f"API Key not allowed to access model. This token can only access models={models}. Tried to access {model}", message=f"API Key not allowed to access model. This token can only access models={models}. Tried to access {model}",
@ -1074,6 +1089,24 @@ async def can_key_call_model(
) )
async def can_team_access_model(
model: str,
team_object: Optional[LiteLLM_TeamTable],
llm_router: Optional[Router],
team_model_aliases: Optional[Dict[str, str]] = None,
) -> Literal[True]:
"""
Returns True if the team can access a specific model.
"""
return await _can_object_call_model(
model=model,
llm_router=llm_router,
models=team_object.models if team_object else [],
team_model_aliases=team_model_aliases,
)
async def can_user_call_model( async def can_user_call_model(
model: str, model: str,
llm_router: Optional[Router], llm_router: Optional[Router],
@ -1239,53 +1272,6 @@ async def _team_max_budget_check(
) )
def _team_model_access_check(
model: Optional[str],
team_object: Optional[LiteLLM_TeamTable],
llm_router: Optional[Router],
team_model_aliases: Optional[Dict[str, str]] = None,
):
"""
Access check for team models
Raises:
Exception if the team is not allowed to call the`model`
"""
if (
model is not None
and team_object is not None
and team_object.models is not None
and len(team_object.models) > 0
and model not in team_object.models
):
# this means the team has access to all models on the proxy
if "all-proxy-models" in team_object.models or "*" in team_object.models:
# this means the team has access to all models on the proxy
pass
# check if the team model is an access_group
elif (
model_in_access_group(
model=model, team_models=team_object.models, llm_router=llm_router
)
is True
):
pass
elif model and "*" in model:
pass
elif _model_in_team_aliases(model=model, team_model_aliases=team_model_aliases):
pass
elif _model_matches_any_wildcard_pattern_in_list(
model=model, allowed_model_list=team_object.models
):
pass
else:
raise ProxyException(
message=f"Team not allowed to access model. Team={team_object.team_id}, Model={model}. Allowed team models = {team_object.models}",
type=ProxyErrorTypes.team_model_access_denied,
param="model",
code=status.HTTP_401_UNAUTHORIZED,
)
def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool: def is_model_allowed_by_pattern(model: str, allowed_model_pattern: str) -> bool:
""" """
Check if a model matches an allowed pattern. Check if a model matches an allowed pattern.