fix(proxy_server.py): show all models user has access to in /models

This commit is contained in:
Krrish Dholakia 2024-01-18 10:56:24 -08:00
parent 9bd945d496
commit c00117679b

View file

@@ -337,12 +337,14 @@ async def user_api_key_auth(
access_groups.append((m["model_name"], group)) access_groups.append((m["model_name"], group))
allowed_models = valid_token.models allowed_models = valid_token.models
access_group_idx = set()
if ( if (
len(access_groups) > 0 len(access_groups) > 0
): # check if token contains any model access groups ): # check if token contains any model access groups
for m in valid_token.models: for idx, m in enumerate(valid_token.models):
for model_name, group in access_groups: for model_name, group in access_groups:
if m == group: if m == group:
access_group_idx.add(idx)
allowed_models.append(model_name) allowed_models.append(model_name)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"model: {model}; allowed_models: {allowed_models}" f"model: {model}; allowed_models: {allowed_models}"
@@ -351,6 +353,12 @@ async def user_api_key_auth(
raise ValueError( raise ValueError(
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}" f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
) )
for val in access_group_idx:
allowed_models.pop(val)
valid_token.models = allowed_models
verbose_proxy_logger.debug(
f"filtered allowed_models: {allowed_models}; valid_token.models: {valid_token.models}"
)
# Check 2. If user_id for this token is in budget # Check 2. If user_id for this token is in budget
if valid_token.user_id is not None: if valid_token.user_id is not None:
@@ -1397,15 +1405,23 @@ async def startup_event():
@router.get( @router.get(
"/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"] "/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
) # if project requires model list ) # if project requires model list
def model_list(): def model_list(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global llm_model_list, general_settings global llm_model_list, general_settings
all_models = [] all_models = []
if general_settings.get("infer_model_from_keys", False): if len(user_api_key_dict.models) > 0:
all_models = litellm.utils.get_valid_models() all_models = user_api_key_dict.models
if llm_model_list: else:
all_models = list(set(all_models + [m["model_name"] for m in llm_model_list])) ## if no specific model access
if user_model is not None: if general_settings.get("infer_model_from_keys", False):
all_models += [user_model] all_models = litellm.utils.get_valid_models()
if llm_model_list:
all_models = list(
set(all_models + [m["model_name"] for m in llm_model_list])
)
if user_model is not None:
all_models += [user_model]
verbose_proxy_logger.debug(f"all_models: {all_models}") verbose_proxy_logger.debug(f"all_models: {all_models}")
### CHECK OLLAMA MODELS ### ### CHECK OLLAMA MODELS ###
try: try:
@@ -2112,7 +2128,7 @@ async def delete_key_fn(request: Request, data: DeleteKeyRequest):
"/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)] "/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
) )
async def info_key_fn( async def info_key_fn(
key: str = fastapi.Query(..., description="Key in the request parameters") key: str = fastapi.Query(..., description="Key in the request parameters"),
): ):
global prisma_client global prisma_client
try: try: