Merge pull request #2066 from BerriAI/litellm_show_if_user_has_model_access

[FEAT] /model/info show models user has access to
This commit is contained in:
Ishaan Jaff 2024-02-19 15:00:17 -08:00 committed by GitHub
commit f8a204c101
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 105 additions and 5 deletions

View file

@ -727,6 +727,7 @@ async def user_api_key_auth(
"/spend",
"/user",
"/model/info",
"/v2/model/info",
]
# check if the current route startswith any of the allowed routes
if (
@ -4328,6 +4329,82 @@ async def add_new_model(model_params: ModelParams):
)
@router.get(
"/v2/model/info",
description="v2 - returns all the models set on the config.yaml, shows 'user_access' = True if the user has access to the model. Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
tags=["model management"],
dependencies=[Depends(user_api_key_auth)],
)
async def model_info_v2(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global llm_model_list, general_settings, user_config_file_path, proxy_config
# Load existing config
config = await proxy_config.get_config()
all_models = config.get("model_list", [])
if user_model is not None:
# if user does not use a config.yaml, https://github.com/BerriAI/litellm/issues/2061
all_models += [user_model]
# check all models user has access to in user_api_key_dict
user_models = []
if len(user_api_key_dict.models) > 0:
user_models = user_api_key_dict.models
# for all models check if the user has access, and mark it as "user_access": `True` or `False`
for model in all_models:
model_name = model.get("model_name", None)
if model_name is not None:
user_has_access = model_name in user_models
if (
user_models == []
): # if user_api_key_dict.models == [], user has access to all models
user_has_access = True
model["user_access"] = user_has_access
# fill in model info based on config.yaml and litellm model_prices_and_context_window.json
for model in all_models:
# provided model_info in config.yaml
model_info = model.get("model_info", {})
# read litellm model_prices_and_context_window.json to get the following:
# input_cost_per_token, output_cost_per_token, max_tokens
litellm_model_info = get_litellm_model_info(model=model)
# 2nd pass on the model, try seeing if we can find model in litellm model_cost map
if litellm_model_info == {}:
# use litellm_param model_name to get model_info
litellm_params = model.get("litellm_params", {})
litellm_model = litellm_params.get("model", None)
try:
litellm_model_info = litellm.get_model_info(model=litellm_model)
except:
litellm_model_info = {}
# 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
if litellm_model_info == {}:
# use litellm_param model_name to get model_info
litellm_params = model.get("litellm_params", {})
litellm_model = litellm_params.get("model", None)
split_model = litellm_model.split("/")
if len(split_model) > 0:
litellm_model = split_model[-1]
try:
litellm_model_info = litellm.get_model_info(model=litellm_model)
except:
litellm_model_info = {}
for k, v in litellm_model_info.items():
if k not in model_info:
model_info[k] = v
model["model_info"] = model_info
# don't return the api key
model["litellm_params"].pop("api_key", None)
verbose_proxy_logger.debug(f"all_models: {all_models}")
return {"data": all_models}
@router.get(
"/model/info",
description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
@ -4981,8 +5058,25 @@ async def auth_callback(request: Request):
if user_id is None:
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
# get user_info from litellm DB
user_info = None
if prisma_client is not None:
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
if user_info is not None:
user_id_models = getattr(user_info, "models", [])
response = await generate_key_helper_fn(
**{"duration": "1hr", "key_max_budget": 0.01, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard", "user_email": user_email} # type: ignore
**{
"duration": "1hr",
"key_max_budget": 0.01,
"models": user_id_models,
"aliases": {},
"config": {},
"spend": 0,
"user_id": user_id,
"team_id": "litellm-dashboard",
"user_email": user_email,
} # type: ignore
)
key = response["token"] # type: ignore
user_id = response["user_id"] # type: ignore