From 25a2f00db6fce76ed4a90f5fcf231d0784418b36 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 25 May 2024 13:02:03 -0700
Subject: [PATCH] fix(proxy_server.py): fix model check for `/v1/models`
 endpoint when team has restricted access

---
 litellm/proxy/_types.py            |  5 ++
 litellm/proxy/auth/model_checks.py | 76 ++++++++++++++++++++++++++++++
 litellm/proxy/proxy_server.py      | 42 +++++++++--------
 litellm/utils.py                   | 15 +++---
 tests/test_keys.py                 | 32 +++++++++----
 5 files changed, 134 insertions(+), 36 deletions(-)
 create mode 100644 litellm/proxy/auth/model_checks.py

diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index e8b3e6572..680ec5494 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1112,3 +1112,8 @@ class WebhookEvent(CallInfo):
     ]
     event_group: Literal["user", "key", "team", "proxy"]
     event_message: str  # human-readable description of event
+
+
+class SpecialModelNames(enum.Enum):
+    all_team_models = "all-team-models"
+    all_proxy_models = "all-proxy-models"
diff --git a/litellm/proxy/auth/model_checks.py b/litellm/proxy/auth/model_checks.py
new file mode 100644
index 000000000..3c874ff0e
--- /dev/null
+++ b/litellm/proxy/auth/model_checks.py
@@ -0,0 +1,76 @@
+# What is this?
+## Common checks for /v1/models and `/model/info`
+from typing import List, Optional
+from litellm.proxy._types import UserAPIKeyAuth, SpecialModelNames
+from litellm.utils import get_valid_models
+from litellm._logging import verbose_proxy_logger
+
+
+def get_key_models(
+    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    """
+    all_models = []
+    if len(user_api_key_dict.models) > 0:
+        all_models = user_api_key_dict.models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_team_models(
+    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    """
+    all_models = []
+    if len(user_api_key_dict.team_models) > 0:
+        all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_complete_model_list(
+    key_models: List[str],
+    team_models: List[str],
+    proxy_model_list: List[str],
+    user_model: Optional[str],
+    infer_model_from_keys: Optional[bool],
+) -> List[str]:
+    """Logic for returning complete model list for a given key + team pair"""
+
+    """
+    - If key list is empty -> defer to team list
+    - If team list is empty -> defer to proxy model list
+    """
+
+    if len(key_models) > 0:
+        return key_models
+
+    if len(team_models) > 0:
+        return team_models
+
+    returned_models = proxy_model_list
+    if user_model is not None:  # set via `litellm --model ollama/llam3`
+        returned_models.append(user_model)
+
+    if infer_model_from_keys is not None and infer_model_from_keys == True:
+        valid_models = get_valid_models()
+        returned_models.extend(valid_models)
+    return returned_models
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 83bc11189..2e301fe8b 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -112,6 +112,11 @@ from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.auth.litellm_license import LicenseCheck
+from litellm.proxy.auth.model_checks import (
+    get_complete_model_list,
+    get_key_models,
+    get_team_models,
+)
 from litellm.proxy.hooks.prompt_injection_detection import (
     _OPTIONAL_PromptInjectionDetection,
 )
@@ -266,10 +271,6 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
     in_memory_cache_ttl = 60  # 1 min ttl ## configure via `general_settings::user_api_key_cache_ttl: `
 
 
-class SpecialModelNames(enum.Enum):
-    all_team_models = "all-team-models"
-
-
 class CommonProxyErrors(enum.Enum):
     db_not_connected_error = "DB not connected"
     no_llm_router = "No models configured on proxy"
@@ -3778,21 +3779,24 @@ def model_list(
 ):
     global llm_model_list, general_settings
     all_models = []
-    if len(user_api_key_dict.models) > 0:
-        all_models = user_api_key_dict.models
-        if SpecialModelNames.all_team_models.value in all_models:
-            all_models = user_api_key_dict.team_models
-    if len(all_models) == 0:  # has all proxy models
-        ## if no specific model access
-        if general_settings.get("infer_model_from_keys", False):
-            all_models = litellm.utils.get_valid_models()
-        if llm_model_list:
-            all_models = list(
-                set(all_models + [m["model_name"] for m in llm_model_list])
-            )
-        if user_model is not None:
-            all_models += [user_model]
-    verbose_proxy_logger.debug("all_models: %s", all_models)
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
     return dict(
         data=[
             {
diff --git a/litellm/utils.py b/litellm/utils.py
index 9750816ff..7d5e5235f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -11441,11 +11441,8 @@ class CustomStreamWrapper:
                 self.response_id = original_chunk.id
                 if len(original_chunk.choices) > 0:
                     delta = original_chunk.choices[0].delta
-                    if (
-                        delta is not None and (
-                            delta.function_call is not None
-                            or delta.tool_calls is not None
-                        )
+                    if delta is not None and (
+                        delta.function_call is not None or delta.tool_calls is not None
                     ):
                         try:
                             model_response.system_fingerprint = (
@@ -11506,7 +11503,11 @@
                         model_response.choices[0].delta = Delta()
                     else:
                         try:
-                            delta = dict() if original_chunk.choices[0].delta is None else dict(original_chunk.choices[0].delta)
+                            delta = (
+                                dict()
+                                if original_chunk.choices[0].delta is None
+                                else dict(original_chunk.choices[0].delta)
+                            )
                             print_verbose(f"original delta: {delta}")
                             model_response.choices[0].delta = Delta(**delta)
                             print_verbose(
@@ -12256,7 +12257,7 @@ def trim_messages(
     return messages
 
 
-def get_valid_models():
+def get_valid_models() -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables
 
diff --git a/tests/test_keys.py b/tests/test_keys.py
index 8da6eba0e..f7256e60f 100644
--- a/tests/test_keys.py
+++ b/tests/test_keys.py
@@ -2,7 +2,7 @@
 ## Tests /key endpoints.
 
 import pytest
-import asyncio, time
+import asyncio, time, uuid
 import aiohttp
 from openai import AsyncOpenAI
 import sys, os
@@ -14,12 +14,14 @@ sys.path.insert(
     0, os.path.abspath("../")
 )  # Adds the parent directory to the system path
 import litellm
-async def generate_team(session):
+async def generate_team(
+    session, models: Optional[list] = None, team_id: Optional[str] = None
+):
     url = "http://0.0.0.0:4000/team/new"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
-    data = {
-        "team_id": "litellm-dashboard",
-    }
+    if team_id is None:
+        team_id = "litellm-dashboard"
+    data = {"team_id": team_id, "models": models}
 
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
@@ -746,19 +748,25 @@ async def test_key_delete_ui():
 
 
 @pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
+@pytest.mark.parametrize("model_access_level", ["key", "team"])
 @pytest.mark.asyncio
-async def test_key_model_list(model_access):
+async def test_key_model_list(model_access, model_access_level):
     """
     Test if `/v1/models` works as expected.
     """
     async with aiohttp.ClientSession() as session:
-        new_team = await generate_team(session=session)
-        team_id = "litellm-dashboard"
+        _models = [] if model_access == "all-team-models" else [model_access]
+        team_id = "litellm_dashboard_{}".format(uuid.uuid4())
+        new_team = await generate_team(
+            session=session,
+            models=_models if model_access_level == "team" else None,
+            team_id=team_id,
+        )
         key_gen = await generate_key(
             session=session,
             i=0,
             team_id=team_id,
-            models=[] if model_access == "all-team-models" else [model_access],
+            models=_models if model_access_level == "key" else [],
         )
         key = key_gen["key"]
         print(f"key: {key}")
@@ -770,5 +778,9 @@
         assert not isinstance(model_list["data"][0]["id"], list)
         assert isinstance(model_list["data"][0]["id"], str)
         if model_access == "gpt-3.5-turbo":
-            assert len(model_list["data"]) == 1
+            assert (
+                len(model_list["data"]) == 1
+            ), "model_access={}, model_access_level={}".format(
+                model_access, model_access_level
+            )
             assert model_list["data"][0]["id"] == model_access
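
Reviewer note, not part of the patch: a minimal sketch of how the new helpers in litellm/proxy/auth/model_checks.py compose for `/v1/models`, mirroring the new wiring in proxy_server.py's model_list(). It assumes a key created with "all-team-models" on a team restricted to one model; the concrete model names and UserAPIKeyAuth field values are hypothetical, and any UserAPIKeyAuth fields not shown are left at their defaults.

    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.model_checks import (
        get_complete_model_list,
        get_key_models,
        get_team_models,
    )

    # Models configured on the proxy (hypothetical).
    proxy_model_list = ["gpt-3.5-turbo", "gpt-4", "claude-3-opus"]

    # A key created with "all-team-models", belonging to a team that is
    # restricted to gpt-3.5-turbo (hypothetical values).
    user_api_key_dict = UserAPIKeyAuth(
        models=["all-team-models"],
        team_models=["gpt-3.5-turbo"],
    )

    key_models = get_key_models(
        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
    )  # ["gpt-3.5-turbo"]: "all-team-models" resolves to the team's list
    team_models = get_team_models(
        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
    )  # ["gpt-3.5-turbo"]
    all_models = get_complete_model_list(
        key_models=key_models,
        team_models=team_models,
        proxy_model_list=proxy_model_list,
        user_model=None,
        infer_model_from_keys=False,
    )  # ["gpt-3.5-turbo"]: the restricted key/team list wins over the full proxy list

With an unrestricted key and team (both lists empty), get_complete_model_list() falls through to proxy_model_list instead, which is what `/v1/models` now returns.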