forked from phoenix/litellm-mirror
feat(auth_checks.py): ensure specific model access > wildcard model access

If a wildcard model is covered by one of the key's access groups but the specific model is not, deny access.
parent a014168c0c · commit 63a9666794

7 changed files with 118 additions and 20 deletions
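In practice: with the example config below, a key scoped to the "public-openai-models" access group may call models that only match the openai/* wildcard, but not openai/gpt-4o, which has its own deployment in "private-openai-models". A minimal sketch of the new behavior, condensed from the test added in this commit (the UserAPIKeyAuth import path is an assumption):

import asyncio

import litellm
from litellm.proxy._types import UserAPIKeyAuth  # assumed import path
from litellm.proxy.auth.auth_checks import can_key_call_model

model_list = [
    {
        "model_name": "openai/*",
        "litellm_params": {"model": "openai/*", "api_key": "test-api-key"},
        "model_info": {"access_groups": ["public-openai-models"]},
    },
    {
        "model_name": "openai/gpt-4o",
        "litellm_params": {"model": "openai/gpt-4o", "api_key": "test-api-key"},
        "model_info": {"access_groups": ["private-openai-models"]},
    },
]
router = litellm.Router(model_list=model_list)
token = UserAPIKeyAuth(models=["public-openai-models"])  # key holds the group, not a model

# Allowed: only the openai/* wildcard (tagged with the key's group) serves this model.
asyncio.run(can_key_call_model("openai/gpt-4o-mini", model_list, token, router))

# Denied: openai/gpt-4o has a specific deployment, tagged only with the private
# group, so the wildcard no longer grants access; can_key_call_model raises.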
@@ -15,6 +15,22 @@ model_list:
     litellm_params:
       model: openai/gpt-4o-realtime-preview-2024-10-01
       api_key: os.environ/OPENAI_API_KEY
+  - model_name: openai/*
+    litellm_params:
+      model: openai/*
+      api_key: os.environ/OPENAI_API_KEY
+  - model_name: openai/*
+    litellm_params:
+      model: openai/*
+      api_key: os.environ/OPENAI_API_KEY
+    model_info:
+      access_groups: ["public-openai-models"]
+  - model_name: openai/gpt-4o
+    litellm_params:
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
+    model_info:
+      access_groups: ["private-openai-models"]
 
 router_settings:
   routing_strategy: usage-based-routing-v2
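For context: an access group lets a virtual key name a group instead of individual models, and the key then covers every deployment tagged with that group. A hypothetical request creating such a key against this config (URL and master key are placeholders; treat as a sketch of the proxy's /key/generate endpoint):

import requests

resp = requests.post(
    "http://localhost:4000/key/generate",         # local proxy URL, placeholder
    headers={"Authorization": "Bearer sk-1234"},  # master key, placeholder
    json={"models": ["public-openai-models"]},    # an access group, not a model name
)
print(resp.json())  # this key can call openai/* matches, but not openai/gpt-4o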
@@ -523,10 +523,6 @@ async def _cache_management_object(
     proxy_logging_obj: Optional[ProxyLogging],
 ):
     await user_api_key_cache.async_set_cache(key=key, value=value)
-    if proxy_logging_obj is not None:
-        await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache(
-            key=key, value=value
-        )
 
 
 async def _cache_team_object(
@@ -878,7 +874,10 @@ async def get_org_object(
 
 
 async def can_key_call_model(
-    model: str, llm_model_list: Optional[list], valid_token: UserAPIKeyAuth
+    model: str,
+    llm_model_list: Optional[list],
+    valid_token: UserAPIKeyAuth,
+    llm_router: Optional[litellm.Router],
 ) -> Literal[True]:
     """
     Checks if token can call a given model
@@ -898,14 +897,14 @@ async def can_key_call_model(
     )
     from collections import defaultdict
 
-    from litellm.proxy.proxy_server import llm_router
-
     access_groups = defaultdict(list)
     if llm_router:
-        access_groups = llm_router.get_model_access_groups()
+        access_groups = llm_router.get_model_access_groups(model_name=model)
 
     models_in_current_access_groups = []
-    if len(access_groups) > 0:  # check if token contains any model access groups
+    if (
+        len(access_groups) > 0 and llm_router is not None
+    ):  # check if token contains any model access groups
         for idx, m in enumerate(
             valid_token.models
         ):  # loop token models, if any of them are an access group add the access group
@@ -923,10 +922,13 @@ async def can_key_call_model(
     all_model_access: bool = False
 
     if (
-        len(filtered_models) == 0
-        or "*" in filtered_models
-        or "openai/*" in filtered_models
-    ):
+        len(filtered_models) == 0 and len(valid_token.models) == 0
+    ) or "*" in filtered_models:
+        all_model_access = True
+    elif (
+        llm_router is not None
+        and llm_router.pattern_router.route(model, filtered_models) is not None
+    ):  # wildcard access
         all_model_access = True
 
     if model is not None and model not in filtered_models and all_model_access is False:
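Net effect of these hunks: access groups are resolved per requested model, and a wildcard deployment's group no longer covers a model that also has its own specific deployment. A rough sketch of the resulting rule, not the actual implementation:

from typing import Dict, List


def allowed(model: str, key_groups: List[str], group_map: Dict[str, List[str]]) -> bool:
    """group_map mirrors get_model_access_groups(model_name=model): each access
    group maps to the deployments that are the best match for `model` only."""
    for group in key_groups:
        members = group_map.get(group, [])
        if model in members:
            return True  # the key's group contains the specific deployment
        if any("*" in m for m in members):
            return True  # the key's group contains a wildcard that serves the model
    return False


# openai/gpt-4o resolves to its specific deployment, which sits only in the
# private group -> a key holding "public-openai-models" is denied, even though
# openai/* would match the name.
assert not allowed("openai/gpt-4o", ["public-openai-models"],
                   {"private-openai-models": ["openai/gpt-4o"]})
assert allowed("openai/gpt-4o-mini", ["public-openai-models"],
               {"public-openai-models": ["openai/*"]})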
@@ -259,6 +259,7 @@ async def user_api_key_auth(  # noqa: PLR0915
         jwt_handler,
         litellm_proxy_admin_name,
         llm_model_list,
+        llm_router,
         master_key,
         open_telemetry_logger,
         prisma_client,
@@ -905,6 +906,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                     model=model,
                     llm_model_list=llm_model_list,
                     valid_token=valid_token,
+                    llm_router=llm_router,
                 )
 
             if fallback_models is not None:
@@ -913,6 +915,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                         model=m,
                         llm_model_list=llm_model_list,
                         valid_token=valid_token,
+                        llm_router=llm_router,
                     )
 
         # Check 2. If user_id for this token is in budget - done in common_checks()
@@ -4712,6 +4712,9 @@ class Router:
         if hasattr(self, "model_list"):
             returned_models: List[DeploymentTypedDict] = []
 
+            if model_name is not None:
+                returned_models.extend(self._get_all_deployments(model_name=model_name))
+
             if hasattr(self, "model_group_alias"):
                 for model_alias, model_value in self.model_group_alias.items():
 
@@ -4743,17 +4746,21 @@ class Router:
                 returned_models += self.model_list
 
                 return returned_models
-            returned_models.extend(self._get_all_deployments(model_name=model_name))
-
             return returned_models
         return None
 
-    def get_model_access_groups(self):
+    def get_model_access_groups(self, model_name: Optional[str] = None):
+        """
+        If model_name is provided, only return access groups for that model.
+        """
         from collections import defaultdict
 
         access_groups = defaultdict(list)
 
-        if self.model_list:
-            for m in self.model_list:
+        model_list = self.get_model_list(model_name=model_name)
+        if model_list:
+            for m in model_list:
                 for group in m.get("model_info", {}).get("access_groups", []):
                     model_name = m["model_name"]
                     access_groups[group].append(model_name)
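With the example config from this commit, the filtered lookup would behave roughly as below (expected outputs are inferred from the change, not captured from a run):

import litellm

router = litellm.Router(model_list=[
    {"model_name": "openai/*",
     "litellm_params": {"model": "openai/*", "api_key": "test-api-key"},
     "model_info": {"access_groups": ["public-openai-models"]}},
    {"model_name": "openai/gpt-4o",
     "litellm_params": {"model": "openai/gpt-4o", "api_key": "test-api-key"},
     "model_info": {"access_groups": ["private-openai-models"]}},
])

# Unfiltered: every group in the config, e.g.
# {"public-openai-models": ["openai/*"], "private-openai-models": ["openai/gpt-4o"]}
print(router.get_model_access_groups())

# Filtered by the requested model: only groups whose deployments serve it, e.g.
# {"private-openai-models": ["openai/gpt-4o"]}
print(router.get_model_access_groups(model_name="openai/gpt-4o"))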
@@ -79,7 +79,9 @@ class PatternMatchRouter:
 
         return new_deployments
 
-    def route(self, request: Optional[str]) -> Optional[List[Dict]]:
+    def route(
+        self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
+    ) -> Optional[List[Dict]]:
         """
         Route a requested model to the corresponding llm deployments based on the regex pattern
 
|
@ -89,14 +91,26 @@ class PatternMatchRouter:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: Optional[str]
|
request: Optional[str]
|
||||||
|
filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
|
||||||
Returns:
|
Returns:
|
||||||
Optional[List[Deployment]]: llm deployments
|
Optional[List[Deployment]]: llm deployments
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if request is None:
|
if request is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
regex_filtered_model_names = (
|
||||||
|
[self._pattern_to_regex(m) for m in filtered_model_names]
|
||||||
|
if filtered_model_names is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
for pattern, llm_deployments in self.patterns.items():
|
for pattern, llm_deployments in self.patterns.items():
|
||||||
|
if (
|
||||||
|
filtered_model_names is not None
|
||||||
|
and pattern not in regex_filtered_model_names
|
||||||
|
):
|
||||||
|
continue
|
||||||
pattern_match = re.match(pattern, request)
|
pattern_match = re.match(pattern, request)
|
||||||
if pattern_match:
|
if pattern_match:
|
||||||
return self._return_pattern_matched_deployments(
|
return self._return_pattern_matched_deployments(
|
||||||
|
|
|
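The filter works at the pattern level: each entry in filtered_model_names is compiled with _pattern_to_regex, and any registered pattern outside that set is skipped before matching. A usage sketch (return values inferred, not captured from a run):

import litellm

router = litellm.Router(model_list=[
    {"model_name": "openai/*",
     "litellm_params": {"model": "openai/*", "api_key": "test-api-key"}},
])

# The wildcard pattern is in the key's model list -> its deployments are returned.
print(router.pattern_router.route("openai/gpt-4o-mini", ["openai/*"]))

# The only registered pattern is filtered out -> route() skips it and returns None.
print(router.pattern_router.route("openai/gpt-4o-mini", ["anthropic/*"]))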
@@ -355,7 +355,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
 class DeploymentTypedDict(TypedDict, total=False):
     model_name: Required[str]
     litellm_params: Required[LiteLLMParamsTypedDict]
-    model_info: Optional[dict]
+    model_info: dict
 
 
 SPECIAL_MODEL_INFO_PARAMS = [
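Since DeploymentTypedDict is declared with total=False, model_info was already an optional key; dropping Optional stops suggesting that an explicit None is a valid value. A deployment entry then type-checks as follows (module path assumed from the class name):

from litellm.types.router import DeploymentTypedDict  # assumed module path

deployment: DeploymentTypedDict = {
    "model_name": "openai/gpt-4o",
    "litellm_params": {"model": "openai/gpt-4o", "api_key": "test-api-key"},
    "model_info": {"access_groups": ["private-openai-models"]},  # a dict, never None
}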
@@ -95,3 +95,59 @@ async def test_handle_failed_db_connection():
     print("_handle_failed_db_connection_for_get_key_object got exception", exc_info)
 
     assert str(exc_info.value) == "Failed to connect to DB"
+
+
+@pytest.mark.parametrize(
+    "model, expect_to_work",
+    [("openai/gpt-4o-mini", True), ("openai/gpt-4o", False)],
+)
+@pytest.mark.asyncio
+async def test_can_key_call_model(model, expect_to_work):
+    """
+    If wildcard model + specific model is used, choose the specific model settings
+    """
+    from litellm.proxy.auth.auth_checks import can_key_call_model
+    from fastapi import HTTPException
+
+    llm_model_list = [
+        {
+            "model_name": "openai/*",
+            "litellm_params": {
+                "model": "openai/*",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
+                "db_model": False,
+                "access_groups": ["public-openai-models"],
+            },
+        },
+        {
+            "model_name": "openai/gpt-4o",
+            "litellm_params": {
+                "model": "openai/gpt-4o",
+                "api_key": "test-api-key",
+            },
+            "model_info": {
+                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
+                "db_model": False,
+                "access_groups": ["private-openai-models"],
+            },
+        },
+    ]
+    router = litellm.Router(model_list=llm_model_list)
+    args = {
+        "model": model,
+        "llm_model_list": llm_model_list,
+        "valid_token": UserAPIKeyAuth(
+            models=["public-openai-models"],
+        ),
+        "llm_router": router,
+    }
+    if expect_to_work:
+        await can_key_call_model(**args)
+    else:
+        with pytest.raises(Exception) as e:
+            await can_key_call_model(**args)
+
+        print(e)