feat(auth_checks.py): ensure specific model access > wildcard model access

If a wildcard model is in the key's access group but the specific model is not, deny access.
commit 63a9666794 (parent a014168c0c)
7 changed files with 118 additions and 20 deletions

```diff
@@ -15,6 +15,22 @@ model_list:
     litellm_params:
       model: openai/gpt-4o-realtime-preview-2024-10-01
       api_key: os.environ/OPENAI_API_KEY
+  - model_name: openai/*
+    litellm_params:
+      model: openai/*
+      api_key: os.environ/OPENAI_API_KEY
+  - model_name: openai/*
+    litellm_params:
+      model: openai/*
+      api_key: os.environ/OPENAI_API_KEY
+    model_info:
+      access_groups: ["public-openai-models"]
+  - model_name: openai/gpt-4o
+    litellm_params:
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
+    model_info:
+      access_groups: ["private-openai-models"]
 
 router_settings:
   routing_strategy: usage-based-routing-v2
```
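
For orientation, here is a minimal sketch of how the router sees the two access-group entries above. The placeholder API keys and the printed mapping are illustrative assumptions, not captured output:

```python
import litellm

# Mirrors the two access-group entries from the config above
# (placeholder api_key values, for illustration only).
model_list = [
    {
        "model_name": "openai/*",
        "litellm_params": {"model": "openai/*", "api_key": "sk-placeholder"},
        "model_info": {"access_groups": ["public-openai-models"]},
    },
    {
        "model_name": "openai/gpt-4o",
        "litellm_params": {"model": "openai/gpt-4o", "api_key": "sk-placeholder"},
        "model_info": {"access_groups": ["private-openai-models"]},
    },
]
router = litellm.Router(model_list=model_list)

# Expected shape, roughly:
# {"public-openai-models": ["openai/*"], "private-openai-models": ["openai/gpt-4o"]}
print(router.get_model_access_groups())
```

A key whose models list names only "public-openai-models" should therefore reach models covered by the openai/* wildcard, but not openai/gpt-4o, which is exposed only through "private-openai-models". That is the behavior the auth changes below enforce.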

```diff
@@ -523,10 +523,6 @@ async def _cache_management_object(
     proxy_logging_obj: Optional[ProxyLogging],
 ):
     await user_api_key_cache.async_set_cache(key=key, value=value)
-    if proxy_logging_obj is not None:
-        await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache(
-            key=key, value=value
-        )
 
 
 async def _cache_team_object(
```

```diff
@@ -878,7 +874,10 @@ async def get_org_object(
 
 
 async def can_key_call_model(
-    model: str, llm_model_list: Optional[list], valid_token: UserAPIKeyAuth
+    model: str,
+    llm_model_list: Optional[list],
+    valid_token: UserAPIKeyAuth,
+    llm_router: Optional[litellm.Router],
 ) -> Literal[True]:
     """
     Checks if token can call a given model
```

```diff
@@ -898,14 +897,14 @@ async def can_key_call_model(
     )
     from collections import defaultdict
 
-    from litellm.proxy.proxy_server import llm_router
-
     access_groups = defaultdict(list)
     if llm_router:
-        access_groups = llm_router.get_model_access_groups()
+        access_groups = llm_router.get_model_access_groups(model_name=model)
 
     models_in_current_access_groups = []
-    if len(access_groups) > 0:  # check if token contains any model access groups
+    if (
+        len(access_groups) > 0 and llm_router is not None
+    ):  # check if token contains any model access groups
         for idx, m in enumerate(
             valid_token.models
         ):  # loop token models, if any of them are an access group add the access group
```

```diff
@@ -923,10 +922,13 @@ async def can_key_call_model(
     all_model_access: bool = False
 
     if (
-        len(filtered_models) == 0
-        or "*" in filtered_models
-        or "openai/*" in filtered_models
-    ):
+        len(filtered_models) == 0 and len(valid_token.models) == 0
+    ) or "*" in filtered_models:
         all_model_access = True
+    elif (
+        llm_router is not None
+        and llm_router.pattern_router.route(model, filtered_models) is not None
+    ):  # wildcard access
+        all_model_access = True
 
     if model is not None and model not in filtered_models and all_model_access is False:
```
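
Continuing the sketch above (same router and model_list), the net effect of these two hunks, run inside an async context, would look roughly like this; the parametrized test at the end of this diff exercises the same two cases:

```python
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import can_key_call_model

# Key scoped to the public access group only.
key = UserAPIKeyAuth(models=["public-openai-models"])

# Allowed: gpt-4o-mini is covered only by the openai/* wildcard deployment,
# and that deployment is in the key's access group.
await can_key_call_model(
    model="openai/gpt-4o-mini",
    llm_model_list=model_list,
    valid_token=key,
    llm_router=router,
)

# Denied: a specific openai/gpt-4o deployment exists, but only in the private
# access group, so the wildcard no longer grants access; this call raises.
await can_key_call_model(
    model="openai/gpt-4o",
    llm_model_list=model_list,
    valid_token=key,
    llm_router=router,
)
```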

```diff
@@ -259,6 +259,7 @@ async def user_api_key_auth(  # noqa: PLR0915
         jwt_handler,
         litellm_proxy_admin_name,
         llm_model_list,
+        llm_router,
         master_key,
         open_telemetry_logger,
         prisma_client,
@@ -905,6 +906,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                     model=model,
                     llm_model_list=llm_model_list,
                     valid_token=valid_token,
+                    llm_router=llm_router,
                 )
 
             if fallback_models is not None:
@@ -913,6 +915,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                         model=m,
                         llm_model_list=llm_model_list,
                         valid_token=valid_token,
+                        llm_router=llm_router,
                     )
 
             # Check 2. If user_id for this token is in budget - done in common_checks()
```

```diff
@@ -4712,6 +4712,9 @@ class Router:
         if hasattr(self, "model_list"):
             returned_models: List[DeploymentTypedDict] = []
 
+            if model_name is not None:
+                returned_models.extend(self._get_all_deployments(model_name=model_name))
+
             if hasattr(self, "model_group_alias"):
                 for model_alias, model_value in self.model_group_alias.items():
```

```diff
@@ -4743,17 +4746,21 @@ class Router:
                 returned_models += self.model_list
 
                 return returned_models
 
-            returned_models.extend(self._get_all_deployments(model_name=model_name))
-
             return returned_models
         return None
 
-    def get_model_access_groups(self):
+    def get_model_access_groups(self, model_name: Optional[str] = None):
+        """
+        If model_name is provided, only return access groups for that model.
+        """
         from collections import defaultdict
 
         access_groups = defaultdict(list)
 
-        if self.model_list:
-            for m in self.model_list:
+        model_list = self.get_model_list(model_name=model_name)
+        if model_list:
+            for m in model_list:
                 for group in m.get("model_info", {}).get("access_groups", []):
                     model_name = m["model_name"]
                     access_groups[group].append(model_name)
```
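
With the new model_name parameter, access groups are resolved per requested model, so a specific deployment shadows the wildcard. Reusing the sketch router from above (return values are illustrative, not captured output):

```python
# Only the specific deployment matches gpt-4o, so only its group is returned:
# {"private-openai-models": ["openai/gpt-4o"]}
print(router.get_model_access_groups(model_name="openai/gpt-4o"))

# No exact match for gpt-4o-mini, so the openai/* wildcard entry is used:
# {"public-openai-models": ["openai/*"]}
print(router.get_model_access_groups(model_name="openai/gpt-4o-mini"))
```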

```diff
@@ -79,7 +79,9 @@ class PatternMatchRouter:
 
         return new_deployments
 
-    def route(self, request: Optional[str]) -> Optional[List[Dict]]:
+    def route(
+        self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
+    ) -> Optional[List[Dict]]:
         """
         Route a requested model to the corresponding llm deployments based on the regex pattern
 
@@ -89,14 +91,26 @@ class PatternMatchRouter:
 
         Args:
             request: Optional[str]
+            filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
 
         Returns:
             Optional[List[Deployment]]: llm deployments
         """
         try:
             if request is None:
                 return None
+
+            regex_filtered_model_names = (
+                [self._pattern_to_regex(m) for m in filtered_model_names]
+                if filtered_model_names is not None
+                else []
+            )
+
             for pattern, llm_deployments in self.patterns.items():
+                if (
+                    filtered_model_names is not None
+                    and pattern not in regex_filtered_model_names
+                ):
+                    continue
                 pattern_match = re.match(pattern, request)
                 if pattern_match:
                     return self._return_pattern_matched_deployments(
```
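
Still with the sketch router, the pattern router now honors the caller's filtered model names: a wildcard deployment is only considered when its pattern appears in the filter (illustrative values):

```python
# The request matches openai/*, and openai/* is in the filter,
# so the wildcard deployments are returned.
print(router.pattern_router.route("openai/gpt-4o-mini", ["openai/*"]))

# Same request, but the filter does not include openai/*, so nothing matches.
print(router.pattern_router.route("openai/gpt-4o-mini", ["anthropic/*"]))  # -> None
```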

```diff
@@ -355,7 +355,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
 class DeploymentTypedDict(TypedDict, total=False):
     model_name: Required[str]
     litellm_params: Required[LiteLLMParamsTypedDict]
-    model_info: Optional[dict]
+    model_info: dict
 
 
 SPECIAL_MODEL_INFO_PARAMS = [
```

```diff
@@ -95,3 +95,59 @@ async def test_handle_failed_db_connection():
         print("_handle_failed_db_connection_for_get_key_object got exception", exc_info)
 
     assert str(exc_info.value) == "Failed to connect to DB"
+
+
+@pytest.mark.parametrize(
+    "model, expect_to_work",
+    [("openai/gpt-4o-mini", True), ("openai/gpt-4o", False)],
+)
+@pytest.mark.asyncio
+async def test_can_key_call_model(model, expect_to_work):
+    """
+    If wildcard model + specific model is used, choose the specific model settings
+    """
+    from litellm.proxy.auth.auth_checks import can_key_call_model
+    from fastapi import HTTPException
+
+    llm_model_list = [
+        {
+            "model_name": "openai/*",
+            "litellm_params": {
+                "model": "openai/*",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
+                "db_model": False,
+                "access_groups": ["public-openai-models"],
+            },
+        },
+        {
+            "model_name": "openai/gpt-4o",
+            "litellm_params": {
+                "model": "openai/gpt-4o",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
+                "db_model": False,
+                "access_groups": ["private-openai-models"],
+            },
+        },
+    ]
+    router = litellm.Router(model_list=llm_model_list)
+    args = {
+        "model": model,
+        "llm_model_list": llm_model_list,
+        "valid_token": UserAPIKeyAuth(
+            models=["public-openai-models"],
+        ),
+        "llm_router": router,
+    }
+    if expect_to_work:
+        await can_key_call_model(**args)
+    else:
+        with pytest.raises(Exception) as e:
+            await can_key_call_model(**args)
+
+        print(e)
```