forked from phoenix/litellm-mirror
fix(auth_checks.py): handle auth checks for team based model access groups
handles scenario where model access group used for wildcard models
This commit is contained in:
parent
63a9666794
commit
aa1621757c
3 changed files with 65 additions and 19 deletions
|
@ -60,6 +60,7 @@ def common_checks( # noqa: PLR0915
|
||||||
global_proxy_spend: Optional[float],
|
global_proxy_spend: Optional[float],
|
||||||
general_settings: dict,
|
general_settings: dict,
|
||||||
route: str,
|
route: str,
|
||||||
|
llm_router: Optional[litellm.Router],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Common checks across jwt + key-based auth.
|
Common checks across jwt + key-based auth.
|
||||||
|
@ -97,7 +98,12 @@ def common_checks( # noqa: PLR0915
|
||||||
# this means the team has access to all models on the proxy
|
# this means the team has access to all models on the proxy
|
||||||
pass
|
pass
|
||||||
# check if the team model is an access_group
|
# check if the team model is an access_group
|
||||||
elif model_in_access_group(_model, team_object.models) is True:
|
elif (
|
||||||
|
model_in_access_group(
|
||||||
|
model=_model, team_models=team_object.models, llm_router=llm_router
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
elif _model and "*" in _model:
|
elif _model and "*" in _model:
|
||||||
pass
|
pass
|
||||||
|
@ -373,36 +379,33 @@ async def get_end_user_object(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
|
def model_in_access_group(
|
||||||
|
model: str, team_models: Optional[List[str]], llm_router: Optional[litellm.Router]
|
||||||
|
) -> bool:
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from litellm.proxy.proxy_server import llm_router
|
|
||||||
|
|
||||||
if team_models is None:
|
if team_models is None:
|
||||||
return True
|
return True
|
||||||
if model in team_models:
|
if model in team_models:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
access_groups = defaultdict(list)
|
access_groups: dict[str, list[str]] = defaultdict(list)
|
||||||
if llm_router:
|
if llm_router:
|
||||||
access_groups = llm_router.get_model_access_groups()
|
access_groups = llm_router.get_model_access_groups(model_name=model)
|
||||||
|
|
||||||
models_in_current_access_groups = []
|
|
||||||
if len(access_groups) > 0: # check if token contains any model access groups
|
if len(access_groups) > 0: # check if token contains any model access groups
|
||||||
for idx, m in enumerate(
|
for idx, m in enumerate(
|
||||||
team_models
|
team_models
|
||||||
): # loop token models, if any of them are an access group add the access group
|
): # loop token models, if any of them are an access group add the access group
|
||||||
if m in access_groups:
|
if m in access_groups:
|
||||||
# if it is an access group we need to remove it from valid_token.models
|
return True
|
||||||
models_in_group = access_groups[m]
|
|
||||||
models_in_current_access_groups.extend(models_in_group)
|
|
||||||
|
|
||||||
# Filter out models that are access_groups
|
# Filter out models that are access_groups
|
||||||
filtered_models = [m for m in team_models if m not in access_groups]
|
filtered_models = [m for m in team_models if m not in access_groups]
|
||||||
filtered_models += models_in_current_access_groups
|
|
||||||
|
|
||||||
if model in filtered_models:
|
if model in filtered_models:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@ -909,9 +912,7 @@ async def can_key_call_model(
|
||||||
valid_token.models
|
valid_token.models
|
||||||
): # loop token models, if any of them are an access group add the access group
|
): # loop token models, if any of them are an access group add the access group
|
||||||
if m in access_groups:
|
if m in access_groups:
|
||||||
# if it is an access group we need to remove it from valid_token.models
|
return True
|
||||||
models_in_group = access_groups[m]
|
|
||||||
models_in_current_access_groups.extend(models_in_group)
|
|
||||||
|
|
||||||
# Filter out models that are access_groups
|
# Filter out models that are access_groups
|
||||||
filtered_models = [m for m in valid_token.models if m not in access_groups]
|
filtered_models = [m for m in valid_token.models if m not in access_groups]
|
||||||
|
@ -925,11 +926,6 @@ async def can_key_call_model(
|
||||||
len(filtered_models) == 0 and len(valid_token.models) == 0
|
len(filtered_models) == 0 and len(valid_token.models) == 0
|
||||||
) or "*" in filtered_models:
|
) or "*" in filtered_models:
|
||||||
all_model_access = True
|
all_model_access = True
|
||||||
elif (
|
|
||||||
llm_router is not None
|
|
||||||
and llm_router.pattern_router.route(model, filtered_models) is not None
|
|
||||||
): # wildcard access
|
|
||||||
all_model_access = True
|
|
||||||
|
|
||||||
if model is not None and model not in filtered_models and all_model_access is False:
|
if model is not None and model not in filtered_models and all_model_access is False:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -543,6 +543,7 @@ async def user_api_key_auth( # noqa: PLR0915
|
||||||
general_settings=general_settings,
|
general_settings=general_settings,
|
||||||
global_proxy_spend=global_proxy_spend,
|
global_proxy_spend=global_proxy_spend,
|
||||||
route=route,
|
route=route,
|
||||||
|
llm_router=llm_router,
|
||||||
)
|
)
|
||||||
|
|
||||||
# return UserAPIKeyAuth object
|
# return UserAPIKeyAuth object
|
||||||
|
@ -1176,6 +1177,7 @@ async def user_api_key_auth( # noqa: PLR0915
|
||||||
general_settings=general_settings,
|
general_settings=general_settings,
|
||||||
global_proxy_spend=global_proxy_spend,
|
global_proxy_spend=global_proxy_spend,
|
||||||
route=route,
|
route=route,
|
||||||
|
llm_router=llm_router,
|
||||||
)
|
)
|
||||||
# Token passed all checks
|
# Token passed all checks
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
|
|
|
@ -151,3 +151,51 @@ async def test_can_key_call_model(model, expect_to_work):
|
||||||
await can_key_call_model(**args)
|
await can_key_call_model(**args)
|
||||||
|
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model, expect_to_work",
|
||||||
|
[("openai/gpt-4o", False), ("openai/gpt-4o-mini", True)],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_can_team_call_model(model, expect_to_work):
|
||||||
|
from litellm.proxy.auth.auth_checks import model_in_access_group
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
llm_model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "openai/*",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "openai/*",
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
},
|
||||||
|
"model_info": {
|
||||||
|
"id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
|
||||||
|
"db_model": False,
|
||||||
|
"access_groups": ["public-openai-models"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "openai/gpt-4o",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "openai/gpt-4o",
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
},
|
||||||
|
"model_info": {
|
||||||
|
"id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
|
||||||
|
"db_model": False,
|
||||||
|
"access_groups": ["private-openai-models"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
router = litellm.Router(model_list=llm_model_list)
|
||||||
|
|
||||||
|
args = {
|
||||||
|
"model": model,
|
||||||
|
"team_models": ["public-openai-models"],
|
||||||
|
"llm_router": router,
|
||||||
|
}
|
||||||
|
if expect_to_work:
|
||||||
|
assert model_in_access_group(**args)
|
||||||
|
else:
|
||||||
|
assert not model_in_access_group(**args)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue