Merge pull request #3839 from BerriAI/litellm_fix_models_endpoint

fix(proxy_server.py): fix model check for `/v1/models` + `/model/info` endpoint when team has restricted access
Authored by Krish Dholakia on 2024-05-25 14:23:01 -07:00, committed by GitHub
commit 02f2d67808
5 changed files with 174 additions and 46 deletions
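
For orientation: the user-visible surface of this fix is the OpenAI-compatible `/v1/models` route (and `/model/info`). A minimal sketch of exercising it with a team-scoped key (the base URL and key below are placeholder assumptions, not values from this PR):

```python
# Placeholder base URL and virtual key; substitute your own proxy deployment.
import requests

PROXY_BASE_URL = "http://localhost:4000"
TEAM_KEY = "sk-my-team-key"  # a key belonging to a team with restricted model access

resp = requests.get(
    f"{PROXY_BASE_URL}/v1/models",
    headers={"Authorization": f"Bearer {TEAM_KEY}"},
)
resp.raise_for_status()

# With this fix, only the models the key/team is actually allowed to call appear.
for model in resp.json()["data"]:
    print(model["id"])
```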

litellm/proxy/proxy_server.py

@@ -111,6 +111,11 @@ from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.auth.litellm_license import LicenseCheck
+from litellm.proxy.auth.model_checks import (
+    get_complete_model_list,
+    get_key_models,
+    get_team_models,
+)
 from litellm.proxy.hooks.prompt_injection_detection import (
     _OPTIONAL_PromptInjectionDetection,
 )
@@ -265,10 +270,6 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
     in_memory_cache_ttl = 60  # 1 min ttl ## configure via `general_settings::user_api_key_cache_ttl: <your-value>`
 
 
-class SpecialModelNames(enum.Enum):
-    all_team_models = "all-team-models"
-
-
 class CommonProxyErrors(enum.Enum):
     db_not_connected_error = "DB not connected"
     no_llm_router = "No models configured on proxy"
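
The `SpecialModelNames.all_team_models` sentinel removed here (presumably relocated alongside the new helpers) is what the deleted branch in the next hunk keyed on: a key whose own `models` list contains `"all-team-models"` defers to its team's model list. In miniature, with simplified stand-ins for the `UserAPIKeyAuth` fields:

```python
# Illustration of the "all-team-models" sentinel; the lists are simplified
# stand-ins for UserAPIKeyAuth.models / UserAPIKeyAuth.team_models.
ALL_TEAM_MODELS = "all-team-models"


def resolve_key_models(key_models: list[str], team_models: list[str]) -> list[str]:
    if ALL_TEAM_MODELS in key_models:
        return team_models  # the key inherits whatever its team allows
    return key_models


assert resolve_key_models(["all-team-models"], ["gpt-4", "claude-3-opus"]) == [
    "gpt-4",
    "claude-3-opus",
]
assert resolve_key_models(["gpt-4"], ["gpt-4", "claude-3-opus"]) == ["gpt-4"]
```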
@@ -3777,21 +3778,24 @@ def model_list(
 ):
     global llm_model_list, general_settings
     all_models = []
-    if len(user_api_key_dict.models) > 0:
-        all_models = user_api_key_dict.models
-        if SpecialModelNames.all_team_models.value in all_models:
-            all_models = user_api_key_dict.team_models
-    if len(all_models) == 0:  # has all proxy models
-        ## if no specific model access
-        if general_settings.get("infer_model_from_keys", False):
-            all_models = litellm.utils.get_valid_models()
-        if llm_model_list:
-            all_models = list(
-                set(all_models + [m["model_name"] for m in llm_model_list])
-            )
-        if user_model is not None:
-            all_models += [user_model]
-    verbose_proxy_logger.debug("all_models: %s", all_models)
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
     return dict(
         data=[
             {
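
The three helpers are imported above but implemented in litellm/proxy/auth/model_checks.py, so their bodies are not part of this diff. A plausible reduction of the precedence the call site implies (key-level restrictions first, then team-level, else every model configured on the proxy); this is an illustration only, not the library's actual code:

```python
# Sketch only: signature inferred from the call site above; the real
# implementation lives in litellm/proxy/auth/model_checks.py.
from typing import List, Optional


def get_complete_model_list(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str],
    infer_model_from_keys: bool,
) -> List[str]:
    if key_models:
        models = list(key_models)  # key is pinned to specific models
    elif team_models:
        models = list(team_models)  # else fall back to the team's allowance
    else:
        models = list(proxy_model_list)  # unrestricted: everything on the proxy
    if user_model is not None:
        models.append(user_model)  # model passed at proxy startup
    # infer_model_from_keys handling omitted; the deleted branch above shows it
    # previously called litellm.utils.get_valid_models().
    return list(dict.fromkeys(models))  # de-duplicate, preserve order


# e.g. a team restricted to gpt-4 on a proxy that also serves gpt-3.5-turbo:
print(get_complete_model_list([], ["gpt-4"], ["gpt-4", "gpt-3.5-turbo"], None, False))
# -> ['gpt-4']
```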
@@ -9640,12 +9644,31 @@ async def model_info_v1(
             status_code=500, detail={"error": "LLM Model List not loaded in"}
         )
 
-    if len(user_api_key_dict.models) > 0:
-        model_names = user_api_key_dict.models
+    all_models: List[dict] = []
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models_str = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
+    if len(all_models_str) > 0:
+        model_names = all_models_str
         _relevant_models = [m for m in llm_model_list if m["model_name"] in model_names]
         all_models = copy.deepcopy(_relevant_models)
     else:
         all_models = copy.deepcopy(llm_model_list)
     for model in all_models:
         # provided model_info in config.yaml
         model_info = model.get("model_info", {})
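
One detail worth calling out in this hunk: the `copy.deepcopy` before the loop is not cosmetic. The loop that follows enriches each entry's `model_info` in place, and with a shallow copy those writes would mutate the proxy's live `llm_model_list`. A small self-contained demonstration:

```python
import copy

# Stand-in for the proxy's live llm_model_list; "db_model" below is just a
# demo key, not necessarily what the real loop writes.
live_config = [{"model_name": "gpt-4", "model_info": {"id": "abc"}}]

# A shallow copy shares the nested dicts, so in-place edits leak back.
shallow = list(live_config)
shallow[0]["model_info"]["db_model"] = True
assert "db_model" in live_config[0]["model_info"]  # leaked into live config

# Reset, then deep-copy: the same edit stays isolated.
del live_config[0]["model_info"]["db_model"]
deep = copy.deepcopy(live_config)
deep[0]["model_info"]["db_model"] = True
assert "db_model" not in live_config[0]["model_info"]
```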