mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
fix(proxy_server.py): show all models user has access to in /models
This commit is contained in:
parent
9bd945d496
commit
c00117679b
1 changed files with 25 additions and 9 deletions
|
@ -337,12 +337,14 @@ async def user_api_key_auth(
|
||||||
access_groups.append((m["model_name"], group))
|
access_groups.append((m["model_name"], group))
|
||||||
|
|
||||||
allowed_models = valid_token.models
|
allowed_models = valid_token.models
|
||||||
|
access_group_idx = set()
|
||||||
if (
|
if (
|
||||||
len(access_groups) > 0
|
len(access_groups) > 0
|
||||||
): # check if token contains any model access groups
|
): # check if token contains any model access groups
|
||||||
for m in valid_token.models:
|
for idx, m in enumerate(valid_token.models):
|
||||||
for model_name, group in access_groups:
|
for model_name, group in access_groups:
|
||||||
if m == group:
|
if m == group:
|
||||||
|
access_group_idx.add(idx)
|
||||||
allowed_models.append(model_name)
|
allowed_models.append(model_name)
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"model: {model}; allowed_models: {allowed_models}"
|
f"model: {model}; allowed_models: {allowed_models}"
|
||||||
|
@ -351,6 +353,12 @@ async def user_api_key_auth(
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
|
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
|
||||||
)
|
)
|
||||||
|
for val in access_group_idx:
|
||||||
|
allowed_models.pop(val)
|
||||||
|
valid_token.models = allowed_models
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"filtered allowed_models: {allowed_models}; valid_token.models: {valid_token.models}"
|
||||||
|
)
|
||||||
|
|
||||||
# Check 2. If user_id for this token is in budget
|
# Check 2. If user_id for this token is in budget
|
||||||
if valid_token.user_id is not None:
|
if valid_token.user_id is not None:
|
||||||
|
@ -1397,15 +1405,23 @@ async def startup_event():
|
||||||
@router.get(
|
@router.get(
|
||||||
"/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
|
"/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
|
||||||
) # if project requires model list
|
) # if project requires model list
|
||||||
def model_list():
|
def model_list(
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
global llm_model_list, general_settings
|
global llm_model_list, general_settings
|
||||||
all_models = []
|
all_models = []
|
||||||
if general_settings.get("infer_model_from_keys", False):
|
if len(user_api_key_dict.models) > 0:
|
||||||
all_models = litellm.utils.get_valid_models()
|
all_models = user_api_key_dict.models
|
||||||
if llm_model_list:
|
else:
|
||||||
all_models = list(set(all_models + [m["model_name"] for m in llm_model_list]))
|
## if no specific model access
|
||||||
if user_model is not None:
|
if general_settings.get("infer_model_from_keys", False):
|
||||||
all_models += [user_model]
|
all_models = litellm.utils.get_valid_models()
|
||||||
|
if llm_model_list:
|
||||||
|
all_models = list(
|
||||||
|
set(all_models + [m["model_name"] for m in llm_model_list])
|
||||||
|
)
|
||||||
|
if user_model is not None:
|
||||||
|
all_models += [user_model]
|
||||||
verbose_proxy_logger.debug(f"all_models: {all_models}")
|
verbose_proxy_logger.debug(f"all_models: {all_models}")
|
||||||
### CHECK OLLAMA MODELS ###
|
### CHECK OLLAMA MODELS ###
|
||||||
try:
|
try:
|
||||||
|
@ -2112,7 +2128,7 @@ async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
||||||
"/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
|
"/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
|
||||||
)
|
)
|
||||||
async def info_key_fn(
|
async def info_key_fn(
|
||||||
key: str = fastapi.Query(..., description="Key in the request parameters")
|
key: str = fastapi.Query(..., description="Key in the request parameters"),
|
||||||
):
|
):
|
||||||
global prisma_client
|
global prisma_client
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue