mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(proxy_server.py): show all models user has access to in /models
This commit is contained in:
parent 658fd4de38
commit c8dd36db9e
1 changed file with 25 additions and 9 deletions
proxy_server.py

@@ -337,12 +337,14 @@ async def user_api_key_auth(
                         access_groups.append((m["model_name"], group))
 
             allowed_models = valid_token.models
+            access_group_idx = set()
             if (
                 len(access_groups) > 0
             ):  # check if token contains any model access groups
-                for m in valid_token.models:
+                for idx, m in enumerate(valid_token.models):
                     for model_name, group in access_groups:
                         if m == group:
+                            access_group_idx.add(idx)
                             allowed_models.append(model_name)
             verbose_proxy_logger.debug(
                 f"model: {model}; allowed_models: {allowed_models}"
@@ -351,6 +353,12 @@ async def user_api_key_auth(
                 raise ValueError(
                     f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
                 )
+            for val in access_group_idx:
+                allowed_models.pop(val)
+            valid_token.models = allowed_models
+            verbose_proxy_logger.debug(
+                f"filtered allowed_models: {allowed_models}; valid_token.models: {valid_token.models}"
+            )
 
         # Check 2. If user_id for this token is in budget
         if valid_token.user_id is not None:
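Taken together, these two hunks expand model access groups into the concrete model names they grant, then strip the group entries themselves from the token's model list. A minimal standalone sketch of that logic follows; the function and variable names here are illustrative, not the proxy's own:

# Illustrative sketch only; mirrors the expand-then-filter steps shown above.
from typing import List, Tuple

def expand_access_groups(
    token_models: List[str], access_groups: List[Tuple[str, str]]
) -> List[str]:
    allowed_models = list(token_models)
    access_group_idx = set()
    for idx, m in enumerate(token_models):
        for model_name, group in access_groups:
            if m == group:
                access_group_idx.add(idx)          # remember the group entry's position
                allowed_models.append(model_name)  # grant the group's concrete model
    # drop the group entries themselves; popping in reverse keeps indices valid
    for idx in sorted(access_group_idx, reverse=True):
        allowed_models.pop(idx)
    return allowed_models

# expand_access_groups(["beta-models", "gpt-3.5-turbo"],
#                      [("gpt-4", "beta-models"), ("claude-2", "beta-models")])
# -> ["gpt-3.5-turbo", "gpt-4", "claude-2"]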
@@ -1397,15 +1405,23 @@ async def startup_event():
 @router.get(
     "/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
 )  # if project requires model list
-def model_list():
+def model_list(
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
     global llm_model_list, general_settings
     all_models = []
-    if general_settings.get("infer_model_from_keys", False):
-        all_models = litellm.utils.get_valid_models()
-    if llm_model_list:
-        all_models = list(set(all_models + [m["model_name"] for m in llm_model_list]))
-    if user_model is not None:
-        all_models += [user_model]
+    if len(user_api_key_dict.models) > 0:
+        all_models = user_api_key_dict.models
+    else:
+        ## if no specific model access
+        if general_settings.get("infer_model_from_keys", False):
+            all_models = litellm.utils.get_valid_models()
+        if llm_model_list:
+            all_models = list(
+                set(all_models + [m["model_name"] for m in llm_model_list])
+            )
+        if user_model is not None:
+            all_models += [user_model]
     verbose_proxy_logger.debug(f"all_models: {all_models}")
     ### CHECK OLLAMA MODELS ###
     try:
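With /models now scoped to the calling key, a key issued with a restricted model list (or access groups) only sees those models, while an unrestricted key still falls back to the full configured list. A hedged usage sketch, assuming a locally running proxy and placeholder values:

# Usage sketch; the base URL and key below are placeholders for your deployment.
import requests

PROXY_BASE_URL = "http://localhost:8000"  # assumption: proxy running locally
API_KEY = "sk-1234"                       # placeholder virtual key

resp = requests.get(
    f"{PROXY_BASE_URL}/models",
    headers={"Authorization": f"Bearer {API_KEY}"},
)
resp.raise_for_status()
print(resp.json())  # a restricted key should only see its own allowed models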
@@ -2112,7 +2128,7 @@ async def delete_key_fn(request: Request, data: DeleteKeyRequest):
     "/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
 )
 async def info_key_fn(
-    key: str = fastapi.Query(..., description="Key in the request parameters")
+    key: str = fastapi.Query(..., description="Key in the request parameters"),
 ):
     global prisma_client
     try:
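The final hunk only adds a trailing comma to info_key_fn's query parameter. For completeness, calling that route follows the same pattern; a hedged sketch with placeholder values:

# Sketch only; the URL and keys are placeholders.
import requests

resp = requests.get(
    "http://localhost:8000/key/info",
    params={"key": "sk-target-key"},                   # key being inspected
    headers={"Authorization": "Bearer sk-admin-key"},  # key used to authenticate
)
print(resp.json())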