(QA / testing) - Add e2e tests for key model access auth checks (#8000)

* fix _model_matches_any_wildcard_pattern_in_list

* test key model access checks

* add key_model_access_denied to ProxyErrorTypes

* update auth checks

* test_model_access_update

* test_team_model_access_patterns

* fix _team_model_access_check

* fix config used for otel testing

* test fix test_call_with_invalid_model

* fix model acces check tests

* test_team_access_groups

* test _model_matches_any_wildcard_pattern_in_list
This commit is contained in:
Ishaan Jaff 2025-01-25 17:15:11 -08:00 committed by GitHub
parent 833a268f4b
commit d19614b8c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 375 additions and 21 deletions

View file

@ -14,12 +14,14 @@ import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from fastapi import status
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
CallInfo,
@ -31,6 +33,8 @@ from litellm.proxy._types import (
LiteLLM_UserTable,
LiteLLMRoutes,
LitellmUserRoles,
ProxyErrorTypes,
ProxyException,
UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
@ -887,8 +891,11 @@ async def can_key_call_model(
all_model_access = True
if model is not None and model not in filtered_models and all_model_access is False:
raise ValueError(
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
raise ProxyException(
message=f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}",
type=ProxyErrorTypes.key_model_access_denied,
param="model",
code=status.HTTP_401_UNAUTHORIZED,
)
valid_token.models = filtered_models
verbose_proxy_logger.debug(
@ -1064,11 +1071,7 @@ def _team_model_access_check(
and model not in team_object.models
):
# this means the team has access to all models on the proxy
if (
"all-proxy-models" in team_object.models
or "*" in team_object.models
or "openai/*" in team_object.models
):
if "all-proxy-models" in team_object.models or "*" in team_object.models:
# this means the team has access to all models on the proxy
pass
# check if the team model is an access_group
@ -1086,8 +1089,11 @@ def _team_model_access_check(
):
pass
else:
raise Exception(
f"Team={team_object.team_id} not allowed to call model={model}. Allowed team models = {team_object.models}"
raise ProxyException(
message=f"Team not allowed to access model. Team={team_object.team_id}, Model={model}. Allowed team models = {team_object.models}",
type=ProxyErrorTypes.team_model_access_denied,
param="model",
code=status.HTTP_401_UNAUTHORIZED,
)
@ -1121,10 +1127,51 @@ def _model_matches_any_wildcard_pattern_in_list(
- model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
- model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
"""
return any(
"*" in allowed_model_pattern
if any(
_is_wildcard_pattern(allowed_model_pattern)
and is_model_allowed_by_pattern(
model=model, allowed_model_pattern=allowed_model_pattern
)
for allowed_model_pattern in allowed_model_list
):
return True
if any(
_is_wildcard_pattern(allowed_model_pattern)
and _model_custom_llm_provider_matches_wildcard_pattern(
model=model, allowed_model_pattern=allowed_model_pattern
)
for allowed_model_pattern in allowed_model_list
):
return True
return False
def _model_custom_llm_provider_matches_wildcard_pattern(
model: str, allowed_model_pattern: str
) -> bool:
"""
Returns True for this scenario:
- `model=gpt-4o`
- `allowed_model_pattern=openai/*`
or
- `model=claude-3-5-sonnet-20240620`
- `allowed_model_pattern=anthropic/*`
"""
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
return is_model_allowed_by_pattern(
model=f"{custom_llm_provider}/{model}",
allowed_model_pattern=allowed_model_pattern,
)
def _is_wildcard_pattern(allowed_model_pattern: str) -> bool:
"""
Returns True if the pattern is a wildcard pattern.
Checks if `*` is in the pattern.
"""
return "*" in allowed_model_pattern