# litellm-mirror/litellm/proxy/auth/model_checks.py
#
# 76 lines
# 2.4 KiB
# Python

# What is this?
## Common checks for /v1/models and `/model/info`
from typing import List, Optional
from litellm.proxy._types import UserAPIKeyAuth, SpecialModelNames
from litellm.utils import get_valid_models
from litellm._logging import verbose_proxy_logger
def get_key_models(
    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
) -> List[str]:
    """
    Resolve the models a key may access from the key's own model list.

    Special placeholder entries are expanded in order:
    - `SpecialModelNames.all_team_models` -> the key's `team_models`
    - `SpecialModelNames.all_proxy_models` (checked on the list produced by
      the previous step) -> `proxy_model_list`

    Returns:
    - List of model name strings
    - Empty list if no models set
    """
    if len(user_api_key_dict.models) == 0:
        resolved: List[str] = []
    else:
        resolved = user_api_key_dict.models
        if SpecialModelNames.all_team_models.value in resolved:
            resolved = user_api_key_dict.team_models
        # Deliberately checked after the team expansion, so a team list that
        # itself contains the all-proxy placeholder is expanded too.
        if SpecialModelNames.all_proxy_models.value in resolved:
            resolved = proxy_model_list

    verbose_proxy_logger.debug(f"ALL KEY MODELS - {len(resolved)}")
    return resolved
def get_team_models(
    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
) -> List[str]:
    """
    Resolve the models a key's team may access.

    If the team's model list contains `SpecialModelNames.all_proxy_models`,
    it is expanded to `proxy_model_list`; otherwise the team's list is
    returned as-is.

    Returns:
    - List of model name strings
    - Empty list if no models set
    """
    all_models: List[str] = []
    if len(user_api_key_dict.team_models) > 0:
        all_models = user_api_key_dict.team_models
        # An `all_team_models` entry has nothing further to expand to at the
        # team level, so it is intentionally left untouched here.
        if SpecialModelNames.all_proxy_models.value in all_models:
            all_models = proxy_model_list
    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
    return all_models
def get_complete_model_list(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str],
    infer_model_from_keys: Optional[bool],
) -> List[str]:
    """
    Return the complete model list for a given key + team pair.

    Precedence:
    - If the key list is non-empty -> return the key list.
    - Else, if the team list is non-empty -> return the team list.
    - Else -> the proxy model list, plus `user_model` (when not None) and,
      when `infer_model_from_keys` is truthy, the models returned by
      `get_valid_models()`.

    The fallback result is a new list with duplicates removed (first
    occurrence wins, order preserved); `proxy_model_list` is never modified.

    Returns:
    - List of model name strings
    """
    if key_models:
        return key_models
    if team_models:
        return team_models

    # Copy instead of aliasing: appending to `proxy_model_list` directly would
    # mutate the caller's list and accumulate duplicates across calls.
    returned_models = list(proxy_model_list)
    if user_model is not None:  # set via `litellm --model ollama/llama3`
        returned_models.append(user_model)
    if infer_model_from_keys:
        returned_models.extend(get_valid_models())
    # dict keeps insertion order, so this dedupes without reordering.
    return list(dict.fromkeys(returned_models))