fix(proxy_server.py): pass model access groups to get_key/get_team mo… (#7281)

* fix(proxy_server.py): pass model access groups to get_key/get_team models

allows end user to see actual models they have access to, instead of default models

* fix(auth_checks.py): fix linting errors

* fix: fix linting errors
This commit is contained in:
Krish Dholakia 2024-12-18 09:33:33 -08:00 committed by GitHub
parent c7ff5a53d7
commit e7918f097b
9 changed files with 118 additions and 20 deletions

View file

@ -101,7 +101,7 @@ def generate_feedback_box():
print() # noqa
import pydantic
from collections import defaultdict
import litellm
from litellm import (
@ -3207,16 +3207,23 @@ async def model_list(
"""
global llm_model_list, general_settings, llm_router
all_models = []
model_access_groups: Dict[str, List[str]] = defaultdict(list)
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
if llm_router is None:
proxy_model_list = []
else:
proxy_model_list = llm_router.get_model_names()
model_access_groups = llm_router.get_model_access_groups()
key_models = get_key_models(
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
user_api_key_dict=user_api_key_dict,
proxy_model_list=proxy_model_list,
model_access_groups=model_access_groups,
)
team_models = get_team_models(
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
user_api_key_dict=user_api_key_dict,
proxy_model_list=proxy_model_list,
model_access_groups=model_access_groups,
)
all_models = get_complete_model_list(
key_models=key_models,
@ -7136,17 +7143,22 @@ async def model_info_v1( # noqa: PLR0915
return {"data": _deployment_info_dict}
all_models: List[dict] = []
model_access_groups: Dict[str, List[str]] = defaultdict(list)
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
if llm_router is None:
proxy_model_list = []
else:
proxy_model_list = llm_router.get_model_names()
model_access_groups = llm_router.get_model_access_groups()
key_models = get_key_models(
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
user_api_key_dict=user_api_key_dict,
proxy_model_list=proxy_model_list,
model_access_groups=model_access_groups,
)
team_models = get_team_models(
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
user_api_key_dict=user_api_key_dict,
proxy_model_list=proxy_model_list,
model_access_groups=model_access_groups,
)
all_models_str = get_complete_model_list(
key_models=key_models,
@ -7358,16 +7370,22 @@ async def model_group_info(
status_code=500, detail={"error": "LLM Router is not loaded in"}
)
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
model_access_groups: Dict[str, List[str]] = defaultdict(list)
if llm_router is None:
proxy_model_list = []
else:
proxy_model_list = llm_router.get_model_names()
model_access_groups = llm_router.get_model_access_groups()
key_models = get_key_models(
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
user_api_key_dict=user_api_key_dict,
proxy_model_list=proxy_model_list,
model_access_groups=model_access_groups,
)
team_models = get_team_models(
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
user_api_key_dict=user_api_key_dict,
proxy_model_list=proxy_model_list,
model_access_groups=model_access_groups,
)
all_models_str = get_complete_model_list(
key_models=key_models,