mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
(QA / testing) - Add e2e tests for key model access auth checks (#8000)
* fix _model_matches_any_wildcard_pattern_in_list * test key model access checks * add key_model_access_denied to ProxyErrorTypes * update auth checks * test_model_access_update * test_team_model_access_patterns * fix _team_model_access_check * fix config used for otel testing * test fix test_call_with_invalid_model * fix model acces check tests * test_team_access_groups * test _model_matches_any_wildcard_pattern_in_list
This commit is contained in:
parent
833a268f4b
commit
d19614b8c0
6 changed files with 375 additions and 21 deletions
|
@ -14,12 +14,14 @@ import time
|
|||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import status
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.proxy._types import (
|
||||
DB_CONNECTION_ERROR_TYPES,
|
||||
CallInfo,
|
||||
|
@ -31,6 +33,8 @@ from litellm.proxy._types import (
|
|||
LiteLLM_UserTable,
|
||||
LiteLLMRoutes,
|
||||
LitellmUserRoles,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
|
@ -887,8 +891,11 @@ async def can_key_call_model(
|
|||
all_model_access = True
|
||||
|
||||
if model is not None and model not in filtered_models and all_model_access is False:
|
||||
raise ValueError(
|
||||
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
|
||||
raise ProxyException(
|
||||
message=f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}",
|
||||
type=ProxyErrorTypes.key_model_access_denied,
|
||||
param="model",
|
||||
code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
valid_token.models = filtered_models
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -1064,11 +1071,7 @@ def _team_model_access_check(
|
|||
and model not in team_object.models
|
||||
):
|
||||
# this means the team has access to all models on the proxy
|
||||
if (
|
||||
"all-proxy-models" in team_object.models
|
||||
or "*" in team_object.models
|
||||
or "openai/*" in team_object.models
|
||||
):
|
||||
if "all-proxy-models" in team_object.models or "*" in team_object.models:
|
||||
# this means the team has access to all models on the proxy
|
||||
pass
|
||||
# check if the team model is an access_group
|
||||
|
@ -1086,8 +1089,11 @@ def _team_model_access_check(
|
|||
):
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"Team={team_object.team_id} not allowed to call model={model}. Allowed team models = {team_object.models}"
|
||||
raise ProxyException(
|
||||
message=f"Team not allowed to access model. Team={team_object.team_id}, Model={model}. Allowed team models = {team_object.models}",
|
||||
type=ProxyErrorTypes.team_model_access_denied,
|
||||
param="model",
|
||||
code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1121,10 +1127,51 @@ def _model_matches_any_wildcard_pattern_in_list(
|
|||
- model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
|
||||
- model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
|
||||
"""
|
||||
return any(
|
||||
"*" in allowed_model_pattern
|
||||
|
||||
if any(
|
||||
_is_wildcard_pattern(allowed_model_pattern)
|
||||
and is_model_allowed_by_pattern(
|
||||
model=model, allowed_model_pattern=allowed_model_pattern
|
||||
)
|
||||
for allowed_model_pattern in allowed_model_list
|
||||
):
|
||||
return True
|
||||
|
||||
if any(
|
||||
_is_wildcard_pattern(allowed_model_pattern)
|
||||
and _model_custom_llm_provider_matches_wildcard_pattern(
|
||||
model=model, allowed_model_pattern=allowed_model_pattern
|
||||
)
|
||||
for allowed_model_pattern in allowed_model_list
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _model_custom_llm_provider_matches_wildcard_pattern(
|
||||
model: str, allowed_model_pattern: str
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True for this scenario:
|
||||
- `model=gpt-4o`
|
||||
- `allowed_model_pattern=openai/*`
|
||||
|
||||
or
|
||||
- `model=claude-3-5-sonnet-20240620`
|
||||
- `allowed_model_pattern=anthropic/*`
|
||||
"""
|
||||
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
|
||||
return is_model_allowed_by_pattern(
|
||||
model=f"{custom_llm_provider}/{model}",
|
||||
allowed_model_pattern=allowed_model_pattern,
|
||||
)
|
||||
|
||||
|
||||
def _is_wildcard_pattern(allowed_model_pattern: str) -> bool:
|
||||
"""
|
||||
Returns True if the pattern is a wildcard pattern.
|
||||
|
||||
Checks if `*` is in the pattern.
|
||||
"""
|
||||
return "*" in allowed_model_pattern
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue