feat(auth_checks.py): ensure specific model access > wildcard model access

If the wildcard model is in one of the key's access groups but the specific model being requested is not, deny access.
Krrish Dholakia 2024-11-29 15:37:16 -08:00
parent a014168c0c
commit 63a9666794
7 changed files with 118 additions and 20 deletions
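
In practice: a key that holds only the wildcard access group can call models served by the wildcard deployment, but it no longer inherits models that are exposed solely through another, more specific access group. A minimal sketch of the intended behavior (names taken from the config and test in this commit; illustrative, not part of the diff):

    # key belongs only to the "public-openai-models" access group
    token = UserAPIKeyAuth(models=["public-openai-models"])

    # allowed: "openai/gpt-4o-mini" is served by the wildcard "openai/*" deployment
    await can_key_call_model(
        model="openai/gpt-4o-mini",
        llm_model_list=llm_model_list,
        valid_token=token,
        llm_router=router,
    )

    # denied: "openai/gpt-4o" is only exposed via "private-openai-models"
    await can_key_call_model(
        model="openai/gpt-4o",
        llm_model_list=llm_model_list,
        valid_token=token,
        llm_router=router,
    )  # raises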

litellm/proxy/proxy_config.yaml

@@ -15,6 +15,22 @@ model_list:
     litellm_params:
       model: openai/gpt-4o-realtime-preview-2024-10-01
       api_key: os.environ/OPENAI_API_KEY
+  - model_name: openai/*
+    litellm_params:
+      model: openai/*
+      api_key: os.environ/OPENAI_API_KEY
+  - model_name: openai/*
+    litellm_params:
+      model: openai/*
+      api_key: os.environ/OPENAI_API_KEY
+    model_info:
+      access_groups: ["public-openai-models"]
+  - model_name: openai/gpt-4o
+    litellm_params:
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
+    model_info:
+      access_groups: ["private-openai-models"]
 
 router_settings:
   routing_strategy: usage-based-routing-v2
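
Note the config registers two openai/* deployments on purpose: only the second carries the public-openai-models access group, while openai/gpt-4o is reachable solely through private-openai-models. The group membership this implies (an illustration, not command output):

    access_groups = {
        "public-openai-models": ["openai/*"],        # wildcard deployment
        "private-openai-models": ["openai/gpt-4o"],  # specific deployment
    }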

litellm/proxy/auth/auth_checks.py

@@ -523,10 +523,6 @@ async def _cache_management_object(
     proxy_logging_obj: Optional[ProxyLogging],
 ):
     await user_api_key_cache.async_set_cache(key=key, value=value)
-    if proxy_logging_obj is not None:
-        await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache(
-            key=key, value=value
-        )
 
 
 async def _cache_team_object(
@@ -878,7 +874,10 @@ async def get_org_object(
 
 
 async def can_key_call_model(
-    model: str, llm_model_list: Optional[list], valid_token: UserAPIKeyAuth
+    model: str,
+    llm_model_list: Optional[list],
+    valid_token: UserAPIKeyAuth,
+    llm_router: Optional[litellm.Router],
 ) -> Literal[True]:
     """
     Checks if token can call a given model
@@ -898,14 +897,14 @@ async def can_key_call_model(
     )
     from collections import defaultdict
 
-    from litellm.proxy.proxy_server import llm_router
-
     access_groups = defaultdict(list)
     if llm_router:
-        access_groups = llm_router.get_model_access_groups()
+        access_groups = llm_router.get_model_access_groups(model_name=model)
 
     models_in_current_access_groups = []
-    if len(access_groups) > 0:  # check if token contains any model access groups
+    if (
+        len(access_groups) > 0 and llm_router is not None
+    ):  # check if token contains any model access groups
         for idx, m in enumerate(
             valid_token.models
         ):  # loop token models, if any of them are an access group add the access group
@@ -923,10 +922,13 @@ async def can_key_call_model(
     all_model_access: bool = False
 
     if (
-        len(filtered_models) == 0
-        or "*" in filtered_models
-        or "openai/*" in filtered_models
-    ):
+        len(filtered_models) == 0 and len(valid_token.models) == 0
+    ) or "*" in filtered_models:
         all_model_access = True
+    elif (
+        llm_router is not None
+        and llm_router.pattern_router.route(model, filtered_models) is not None
+    ):  # wildcard access
+        all_model_access = True
 
     if model is not None and model not in filtered_models and all_model_access is False:
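
Read together, the new checks in can_key_call_model amount to roughly this precedence (a simplified paraphrase of the diff above, not the verbatim function):

    def has_model_access(model, filtered_models, valid_token, llm_router) -> bool:
        # 1. all-model access: the key lists no models at all, or holds "*"
        if (
            len(filtered_models) == 0 and len(valid_token.models) == 0
        ) or "*" in filtered_models:
            return True
        # 2. wildcard access: the requested model matches one of the key's own
        #    wildcard patterns (e.g. "openai/*"), resolved by the pattern router
        if (
            llm_router is not None
            and llm_router.pattern_router.route(model, filtered_models) is not None
        ):
            return True
        # 3. otherwise the specific model must be listed explicitly
        return model in filtered_models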

litellm/proxy/auth/user_api_key_auth.py

@@ -259,6 +259,7 @@ async def user_api_key_auth(  # noqa: PLR0915
         jwt_handler,
         litellm_proxy_admin_name,
         llm_model_list,
+        llm_router,
         master_key,
         open_telemetry_logger,
         prisma_client,
@@ -905,6 +906,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                 model=model,
                 llm_model_list=llm_model_list,
                 valid_token=valid_token,
+                llm_router=llm_router,
             )
 
         if fallback_models is not None:
if fallback_models is not None:
@@ -913,6 +915,7 @@
                     model=m,
                     llm_model_list=llm_model_list,
                     valid_token=valid_token,
+                    llm_router=llm_router,
                 )
 
         # Check 2. If user_id for this token is in budget - done in common_checks()

litellm/router.py

@@ -4712,6 +4712,9 @@
         if hasattr(self, "model_list"):
             returned_models: List[DeploymentTypedDict] = []
 
+            if model_name is not None:
+                returned_models.extend(self._get_all_deployments(model_name=model_name))
+
             if hasattr(self, "model_group_alias"):
                 for model_alias, model_value in self.model_group_alias.items():
@@ -4743,17 +4746,21 @@
                 returned_models += self.model_list
 
                 return returned_models
 
-            returned_models.extend(self._get_all_deployments(model_name=model_name))
             return returned_models
         return None
 
-    def get_model_access_groups(self):
+    def get_model_access_groups(self, model_name: Optional[str] = None):
+        """
+        If model_name is provided, only return access groups for that model.
+        """
         from collections import defaultdict
 
         access_groups = defaultdict(list)
 
-        if self.model_list:
-            for m in self.model_list:
+        model_list = self.get_model_list(model_name=model_name)
+        if model_list:
+            for m in model_list:
                 for group in m.get("model_info", {}).get("access_groups", []):
                     model_name = m["model_name"]
                     access_groups[group].append(model_name)
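
With the deployments used in the test below, the narrowed lookup would behave roughly as follows (expected values inferred from the test; this assumes get_model_list prefers exact model-name matches over wildcard matches):

    router.get_model_access_groups()
    # {"public-openai-models": ["openai/*"], "private-openai-models": ["openai/gpt-4o"]}

    router.get_model_access_groups(model_name="openai/gpt-4o")
    # {"private-openai-models": ["openai/gpt-4o"]}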

litellm/router_utils/pattern_match_deployments.py

@@ -79,7 +79,9 @@ class PatternMatchRouter:
 
         return new_deployments
 
-    def route(self, request: Optional[str]) -> Optional[List[Dict]]:
+    def route(
+        self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
+    ) -> Optional[List[Dict]]:
         """
         Route a requested model to the corresponding llm deployments based on the regex pattern
@@ -89,14 +91,26 @@
 
         Args:
             request: Optional[str]
+            filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
 
         Returns:
             Optional[List[Deployment]]: llm deployments
         """
         try:
             if request is None:
                 return None
+
+            regex_filtered_model_names = (
+                [self._pattern_to_regex(m) for m in filtered_model_names]
+                if filtered_model_names is not None
+                else []
+            )
             for pattern, llm_deployments in self.patterns.items():
+                if (
+                    filtered_model_names is not None
+                    and pattern not in regex_filtered_model_names
+                ):
+                    continue
                 pattern_match = re.match(pattern, request)
                 if pattern_match:
                     return self._return_pattern_matched_deployments(
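
The new filtered_model_names parameter restricts matching to patterns the key itself holds: each filter is converted with the existing _pattern_to_regex helper and compared against the registered pattern keys. A hedged usage sketch (return values inferred from the docstring):

    # key holds the wildcard itself -> the request matches a permitted pattern
    router.pattern_router.route(
        "openai/gpt-4o-mini", filtered_model_names=["openai/*"]
    )  # -> list of matching deployments

    # key holds only an access-group name -> every pattern is filtered out
    router.pattern_router.route(
        "openai/gpt-4o-mini", filtered_model_names=["public-openai-models"]
    )  # -> None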

litellm/types/router.py

@@ -355,7 +355,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
 
 class DeploymentTypedDict(TypedDict, total=False):
     model_name: Required[str]
     litellm_params: Required[LiteLLMParamsTypedDict]
-    model_info: Optional[dict]
+    model_info: dict
 
 
 SPECIAL_MODEL_INFO_PARAMS = [
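
Since DeploymentTypedDict is declared with total=False, the key is already optional; Optional[dict] would instead claim the value may be None, which is a different statement (PEP 655 draws exactly this distinction):

    from typing import Optional, TypedDict

    class Example(TypedDict, total=False):
        model_info: dict            # key may be absent; when present, value is a dict
        maybe_none: Optional[dict]  # key may be absent; when present, value may be None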

tests/proxy_unit_tests/test_auth_checks.py

@@ -95,3 +95,59 @@ async def test_handle_failed_db_connection():
         print("_handle_failed_db_connection_for_get_key_object got exception", exc_info)
         assert str(exc_info.value) == "Failed to connect to DB"
+
+
+@pytest.mark.parametrize(
+    "model, expect_to_work",
+    [("openai/gpt-4o-mini", True), ("openai/gpt-4o", False)],
+)
+@pytest.mark.asyncio
+async def test_can_key_call_model(model, expect_to_work):
+    """
+    If wildcard model + specific model is used, choose the specific model settings
+    """
+    from litellm.proxy.auth.auth_checks import can_key_call_model
+    from fastapi import HTTPException
+
+    llm_model_list = [
+        {
+            "model_name": "openai/*",
+            "litellm_params": {
+                "model": "openai/*",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
+                "db_model": False,
+                "access_groups": ["public-openai-models"],
+            },
+        },
+        {
+            "model_name": "openai/gpt-4o",
+            "litellm_params": {
+                "model": "openai/gpt-4o",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
+                "db_model": False,
+                "access_groups": ["private-openai-models"],
+            },
+        },
+    ]
+    router = litellm.Router(model_list=llm_model_list)
+
+    args = {
+        "model": model,
+        "llm_model_list": llm_model_list,
+        "valid_token": UserAPIKeyAuth(
+            models=["public-openai-models"],
+        ),
+        "llm_router": router,
+    }
+    if expect_to_work:
+        await can_key_call_model(**args)
+    else:
+        with pytest.raises(Exception) as e:
+            await can_key_call_model(**args)
+        print(e)