From 95566dc249dfb00569cf0f5943708eac6c60923c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 25 May 2024 13:21:33 -0700 Subject: [PATCH] fix(proxy_server.py): fix model list returned for `/model/info` when team has restricted access --- litellm/proxy/proxy_server.py | 27 +++++++++++++++++++++++---- tests/test_keys.py | 35 +++++++++++++++++++++++------------ 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 2e301fe8b..bf4d825be 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -9395,12 +9395,31 @@ async def model_info_v1( status_code=500, detail={"error": "LLM Model List not loaded in"} ) - if len(user_api_key_dict.models) > 0: - model_names = user_api_key_dict.models + all_models: List[dict] = [] + ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ## + if llm_model_list is None: + proxy_model_list = [] + else: + proxy_model_list = [m["model_name"] for m in llm_model_list] + key_models = get_key_models( + user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list + ) + team_models = get_team_models( + user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list + ) + all_models_str = get_complete_model_list( + key_models=key_models, + team_models=team_models, + proxy_model_list=proxy_model_list, + user_model=user_model, + infer_model_from_keys=general_settings.get("infer_model_from_keys", False), + ) + + if len(all_models_str) > 0: + model_names = all_models_str _relevant_models = [m for m in llm_model_list if m["model_name"] in model_names] all_models = copy.deepcopy(_relevant_models) - else: - all_models = copy.deepcopy(llm_model_list) + for model in all_models: # provided model_info in config.yaml model_info = model.get("model_info", {}) diff --git a/tests/test_keys.py b/tests/test_keys.py index f7256e60f..11961e2a2 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -359,11 +359,11 @@ async def get_key_info(session, call_key, get_key=None): return await response.json() -async def get_model_list(session, call_key): +async def get_model_list(session, call_key, endpoint: str = "/v1/models"): """ Make sure only models user has access to are returned """ - url = "http://0.0.0.0:4000/v1/models" + url = "http://0.0.0.0:4000" + endpoint headers = { "Authorization": f"Bearer {call_key}", "Content-Type": "application/json", @@ -749,8 +749,9 @@ async def test_key_delete_ui(): @pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"]) @pytest.mark.parametrize("model_access_level", ["key", "team"]) +@pytest.mark.parametrize("model_endpoint", ["/v1/models", "/model/info"]) @pytest.mark.asyncio -async def test_key_model_list(model_access, model_access_level): +async def test_key_model_list(model_access, model_access_level, model_endpoint): """ Test if `/v1/models` works as expected. """ @@ -771,16 +772,26 @@ async def test_key_model_list(model_access, model_access_level): key = key_gen["key"] print(f"key: {key}") - model_list = await get_model_list(session=session, call_key=key) + model_list = await get_model_list( + session=session, call_key=key, endpoint=model_endpoint + ) print(f"model_list: {model_list}") if model_access == "all-team-models": - assert not isinstance(model_list["data"][0]["id"], list) - assert isinstance(model_list["data"][0]["id"], str) + if model_endpoint == "/v1/models": + assert not isinstance(model_list["data"][0]["id"], list) + assert isinstance(model_list["data"][0]["id"], str) + elif model_endpoint == "/model/info": + assert isinstance(model_list["data"], list) + assert len(model_list["data"]) > 0 if model_access == "gpt-3.5-turbo": - assert ( - len(model_list["data"]) == 1 - ), "model_access={}, model_access_level={}".format( - model_access, model_access_level - ) - assert model_list["data"][0]["id"] == model_access + if model_endpoint == "/v1/models": + assert ( + len(model_list["data"]) == 1 + ), "model_access={}, model_access_level={}".format( + model_access, model_access_level + ) + assert model_list["data"][0]["id"] == model_access + elif model_endpoint == "/model/info": + assert isinstance(model_list["data"], list) + assert len(model_list["data"]) == 1