diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index d3667892bb..cde397d387 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -337,12 +337,14 @@ async def user_api_key_auth(
                             access_groups.append((m["model_name"], group))
 
                 allowed_models = valid_token.models
+                access_group_idx = set()
                 if (
                     len(access_groups) > 0
                 ):  # check if token contains any model access groups
-                    for m in valid_token.models:
+                    for idx, m in enumerate(valid_token.models):
                         for model_name, group in access_groups:
                             if m == group:
+                                access_group_idx.add(idx)
                                 allowed_models.append(model_name)
                 verbose_proxy_logger.debug(
                     f"model: {model}; allowed_models: {allowed_models}"
@@ -351,6 +353,12 @@
                     raise ValueError(
                         f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
                     )
+                for val in access_group_idx:
+                    allowed_models.pop(val)
+                valid_token.models = allowed_models
+                verbose_proxy_logger.debug(
+                    f"filtered allowed_models: {allowed_models}; valid_token.models: {valid_token.models}"
+                )
 
             # Check 2. If user_id for this token is in budget
             if valid_token.user_id is not None:
@@ -1397,15 +1405,23 @@ async def startup_event():
 @router.get(
     "/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
 )  # if project requires model list
-def model_list():
+def model_list(
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
     global llm_model_list, general_settings
     all_models = []
-    if general_settings.get("infer_model_from_keys", False):
-        all_models = litellm.utils.get_valid_models()
-    if llm_model_list:
-        all_models = list(set(all_models + [m["model_name"] for m in llm_model_list]))
-    if user_model is not None:
-        all_models += [user_model]
+    if len(user_api_key_dict.models) > 0:
+        all_models = user_api_key_dict.models
+    else:
+        ## if no specific model access
+        if general_settings.get("infer_model_from_keys", False):
+            all_models = litellm.utils.get_valid_models()
+        if llm_model_list:
+            all_models = list(
+                set(all_models + [m["model_name"] for m in llm_model_list])
+            )
+        if user_model is not None:
+            all_models += [user_model]
     verbose_proxy_logger.debug(f"all_models: {all_models}")
     ### CHECK OLLAMA MODELS ###
     try:
@@ -2112,7 +2128,7 @@ async def delete_key_fn(request: Request, data: DeleteKeyRequest):
     "/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
 )
 async def info_key_fn(
-    key: str = fastapi.Query(..., description="Key in the request parameters")
+    key: str = fastapi.Query(..., description="Key in the request parameters"),
 ):
     global prisma_client
     try:
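
For reviewers, a minimal standalone sketch of the access-group expansion and filtering introduced in the first two hunks. Variable names mirror the diff; the sample data (`beta-models` group, model names) is hypothetical and only for illustration:

```python
# Sketch of the access-group handling added to user_api_key_auth (names mirror the diff).
# access_groups is a list of (model_name, group) pairs, as built from model_info.access_groups.
access_groups = [("gpt-4", "beta-models"), ("claude-2", "beta-models")]

# Models configured on the key: one access-group name plus one concrete model (made-up values).
valid_token_models = ["beta-models", "gpt-3.5-turbo"]

allowed_models = list(valid_token_models)
access_group_idx = set()
for idx, m in enumerate(valid_token_models):
    for model_name, group in access_groups:
        if m == group:
            # Remember which entries were group names, and expand the group
            # into the concrete models it grants access to.
            access_group_idx.add(idx)
            allowed_models.append(model_name)

# Drop the group-name placeholders so only concrete model names remain;
# this filtered list is what the updated /models endpoint can return for the key.
for val in access_group_idx:
    allowed_models.pop(val)

print(allowed_models)  # ['gpt-3.5-turbo', 'gpt-4', 'claude-2']
```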