From f2766fddbfc743ab8be657694118d830551cb2bb Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 7 May 2024 13:42:58 -0700 Subject: [PATCH] fix(proxy_server.py): fix `/v1/models` bug where it would return empty list; handle 'all-team-models' being set for a given key --- litellm/proxy/proxy_server.py | 8 +++++- tests/test_keys.py | 54 +++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c22b381e2..bb0734e1d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -227,6 +227,10 @@ class UserAPIKeyCacheTTLEnum(enum.Enum): global_proxy_spend = 60 +class SpecialModelNames(enum.Enum): + all_team_models = "all-team-models" + + @app.exception_handler(ProxyException) async def openai_exception_handler(request: Request, exc: ProxyException): # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions @@ -3436,7 +3440,9 @@ def model_list( all_models = [] if len(user_api_key_dict.models) > 0: all_models = user_api_key_dict.models - else: + if SpecialModelNames.all_team_models.value in all_models: + all_models = user_api_key_dict.team_models + if len(all_models) == 0: # has all proxy models ## if no specific model access if general_settings.get("infer_model_from_keys", False): all_models = litellm.utils.get_valid_models() diff --git a/tests/test_keys.py b/tests/test_keys.py index 3af7dea66..89dadce36 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -62,6 +62,7 @@ async def generate_key( models=["azure-models", "gpt-4", "dall-e-3"], max_parallel_requests: Optional[int] = None, user_id: Optional[str] = None, + team_id: Optional[str] = None, calling_key="sk-1234", ): url = "http://0.0.0.0:4000/key/generate" @@ -77,6 +78,7 @@ async def generate_key( "budget_duration": budget_duration, "max_parallel_requests": max_parallel_requests, "user_id": user_id, + "team_id": team_id, } print(f"data: {data}") @@ -355,6 +357,29 @@ 
async def get_key_info(session, call_key, get_key=None): return await response.json() +async def get_model_list(session, call_key): + """ + Make sure only models user has access to are returned + """ + url = "http://0.0.0.0:4000/v1/models" + headers = { + "Authorization": f"Bearer {call_key}", + "Content-Type": "application/json", + } + + async with session.get(url, headers=headers) as response: + status = response.status + response_text = await response.text() + print(response_text) + print() + + if status != 200: + raise Exception( + f"Request did not return a 200 status code: {status}. Responses {response_text}" + ) + return await response.json() + + async def get_model_info(session, call_key): """ Make sure only models user has access to are returned @@ -719,3 +744,32 @@ async def test_key_delete_ui(): get_key=key, auth_key=admin_ui_key["key"], ) + + +@pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"]) +@pytest.mark.asyncio +async def test_key_model_list(model_access): + """ + Test if `/v1/models` works as expected. + """ + async with aiohttp.ClientSession() as session: + new_team = await generate_team(session=session) + team_id = new_team["team_id"] + key_gen = await generate_key( + session=session, + i=0, + team_id=team_id, + models=[] if model_access == "all-team-models" else [model_access], + ) + key = key_gen["key"] + print(f"key: {key}") + + model_list = await get_model_list(session=session, call_key=key) + print(f"model_list: {model_list}") + + if model_access == "all-team-models": + assert not isinstance(model_list["data"][0]["id"], list) + assert isinstance(model_list["data"][0]["id"], str) + if model_access == "gpt-3.5-turbo": + assert len(model_list["data"]) == 1 + assert model_list["data"][0]["id"] == model_access