From 5729eb5168253017dd37209f49f59adccbca695a Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 20 Jun 2024 16:02:19 -0700
Subject: [PATCH] fix(user_api_key_auth.py): ensure user has access to fallback models

For client-side fallbacks, checks if the user has access to the fallback
models.
---
 litellm/proxy/auth/auth_checks.py       | 60 +++++++++++++++++++
 litellm/proxy/auth/user_api_key_auth.py | 63 +++++---------------
 litellm/router.py                       |  2 +-
 tests/test_fallbacks.py                 | 78 +++++++++++++++++++++++--
 4 files changed, 150 insertions(+), 53 deletions(-)

diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 9c3a79f58..e404a1d40 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -12,6 +12,7 @@ from datetime import datetime
 from typing import TYPE_CHECKING, Any, Literal, Optional
 
 import litellm
+from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.proxy._types import (
     LiteLLM_EndUserTable,
@@ -21,6 +22,7 @@ from litellm.proxy._types import (
     LiteLLM_UserTable,
     LiteLLMRoutes,
     LitellmUserRoles,
+    UserAPIKeyAuth,
 )
 from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
@@ -431,3 +433,61 @@ async def get_org_object(
         raise Exception(
             f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
         )
+
+
+async def can_key_call_model(
+    model: str, llm_model_list: Optional[list], valid_token: UserAPIKeyAuth
+) -> Literal[True]:
+    """
+    Checks if token can call a given model
+
+    Returns:
+        - True: if token allowed to call model
+
+    Raises:
+        - Exception: If token not allowed to call model
+    """
+    if model in litellm.model_alias_map:
+        model = litellm.model_alias_map[model]
+
+    ## check if model in allowed model names
+    verbose_proxy_logger.debug(
+        f"LLM Model List pre access group check: {llm_model_list}"
+    )
+    from collections import defaultdict
+
+    access_groups = defaultdict(list)
+    if llm_model_list is not None:
+        for m in llm_model_list:
+            for group in m.get("model_info", {}).get("access_groups", []):
+                model_name = m["model_name"]
+                access_groups[group].append(model_name)
+
+    models_in_current_access_groups = []
+    if len(access_groups) > 0:  # check if token contains any model access groups
+        for idx, m in enumerate(
+            valid_token.models
+        ):  # loop token models, if any of them are an access group add the access group
+            if m in access_groups:
+                # if it is an access group we need to remove it from valid_token.models
+                models_in_group = access_groups[m]
+                models_in_current_access_groups.extend(models_in_group)
+
+    # Filter out models that are access_groups
+    filtered_models = [m for m in valid_token.models if m not in access_groups]
+
+    filtered_models += models_in_current_access_groups
+    verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
+    if (
+        model is not None
+        and model not in filtered_models
+        and "*" not in filtered_models
+    ):
+        raise ValueError(
+            f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
+        )
+    valid_token.models = filtered_models
+    verbose_proxy_logger.debug(
+        f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
+    )
+    return True
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 9ab76b8d8..3d14f5300 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -47,6 +47,7 @@ from litellm._logging import verbose_logger, verbose_proxy_logger
 from litellm.proxy._types import *
 from litellm.proxy.auth.auth_checks import (
     allowed_routes_check,
+    can_key_call_model,
     common_checks,
     get_actual_routes,
     get_end_user_object,
@@ -494,6 +495,7 @@ async def user_api_key_auth(
             # Got Valid Token from Cache, DB
             # Run checks for
             # 1. If token can call model
+            ## 1a. If token can call fallback models (if client-side fallbacks given)
             # 2. If user_id for this token is in budget
             # 3. If the user spend within their own team is within budget
             # 4. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
@@ -540,55 +542,22 @@ async def user_api_key_auth(
                 except json.JSONDecodeError:
                     data = {}  # Provide a default value, such as an empty dictionary
                 model = data.get("model", None)
-                if model in litellm.model_alias_map:
-                    model = litellm.model_alias_map[model]
+                fallback_models: Optional[List[str]] = data.get("fallbacks", None)
 
-                ## check if model in allowed model names
-                verbose_proxy_logger.debug(
-                    f"LLM Model List pre access group check: {llm_model_list}"
-                )
-                from collections import defaultdict
-
-                access_groups = defaultdict(list)
-                if llm_model_list is not None:
-                    for m in llm_model_list:
-                        for group in m.get("model_info", {}).get("access_groups", []):
-                            model_name = m["model_name"]
-                            access_groups[group].append(model_name)
-
-                models_in_current_access_groups = []
-                if (
-                    len(access_groups) > 0
-                ):  # check if token contains any model access groups
-                    for idx, m in enumerate(
-                        valid_token.models
-                    ):  # loop token models, if any of them are an access group add the access group
-                        if m in access_groups:
-                            # if it is an access group we need to remove it from valid_token.models
-                            models_in_group = access_groups[m]
-                            models_in_current_access_groups.extend(models_in_group)
-
-                # Filter out models that are access_groups
-                filtered_models = [
-                    m for m in valid_token.models if m not in access_groups
-                ]
-
-                filtered_models += models_in_current_access_groups
-                verbose_proxy_logger.debug(
-                    f"model: {model}; allowed_models: {filtered_models}"
-                )
-                if (
-                    model is not None
-                    and model not in filtered_models
-                    and "*" not in filtered_models
-                ):
-                    raise ValueError(
-                        f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
+                if model is not None:
+                    await can_key_call_model(
+                        model=model,
+                        llm_model_list=llm_model_list,
+                        valid_token=valid_token,
                     )
-                valid_token.models = filtered_models
-                verbose_proxy_logger.debug(
-                    f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
-                )
+
+                if fallback_models is not None:
+                    for m in fallback_models:
+                        await can_key_call_model(
+                            model=m,
+                            llm_model_list=llm_model_list,
+                            valid_token=valid_token,
+                        )
 
             # Check 2. If user_id for this token is in budget
             if valid_token.user_id is not None:
diff --git a/litellm/router.py b/litellm/router.py
index b4589c9f0..638df2bf0 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -2087,7 +2087,7 @@ class Router:
             "content_policy_fallbacks", self.content_policy_fallbacks
         )
         try:
-            if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
+            if mock_testing_fallbacks is not None and mock_testing_fallbacks is True:
                 raise Exception(
                     f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
                 )
diff --git a/tests/test_fallbacks.py b/tests/test_fallbacks.py
index b87ff3706..e31761e10 100644
--- a/tests/test_fallbacks.py
+++ b/tests/test_fallbacks.py
@@ -6,16 +6,44 @@ import aiohttp
 from large_text import text
 
 
-async def chat_completion(session, key: str, model: str, messages: list):
+async def generate_key(
+    session,
+    i,
+    models: list,
+    calling_key="sk-1234",
+):
+    url = "http://0.0.0.0:4000/key/generate"
+    headers = {
+        "Authorization": f"Bearer {calling_key}",
+        "Content-Type": "application/json",
+    }
+    data = {
+        "models": models,
+    }
+
+    print(f"data: {data}")
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(f"Response {i} (Status code: {status}):")
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request {i} did not return a 200 status code: {status}")
+
+        return await response.json()
+
+
+async def chat_completion(session, key: str, model: str, messages: list, **kwargs):
     url = "http://0.0.0.0:4000/chat/completions"
     headers = {
         "Authorization": f"Bearer {key}",
         "Content-Type": "application/json",
     }
-    data = {
-        "model": model,
-        "messages": messages,
-    }
+    data = {"model": model, "messages": messages, **kwargs}
 
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
@@ -43,3 +71,43 @@ async def test_chat_completion():
             await chat_completion(
                 session=session, key="sk-1234", model=model, messages=messages
             )
+
+
+@pytest.mark.parametrize("has_access", [True, False])
+@pytest.mark.asyncio
+async def test_chat_completion_client_fallbacks(has_access):
+    """
+    Make a chat completion call with a client-side fallback; expect it to succeed only when the key has access to the fallback model
+    """
+
+    async with aiohttp.ClientSession() as session:
+        models = ["gpt-3.5-turbo"]
+
+        if has_access:
+            models.append("gpt-instruct")
+
+        ## CREATE KEY WITH MODELS
+        generated_key = await generate_key(session=session, i=0, models=models)
+        calling_key = generated_key["key"]
+        model = "gpt-3.5-turbo"
+        messages = [
+            {"role": "user", "content": "Who was Alexander?"},
+        ]
+
+        ## CALL PROXY
+        try:
+            await chat_completion(
+                session=session,
+                key=calling_key,
+                model=model,
+                messages=messages,
+                mock_testing_fallbacks=True,
+                fallbacks=["gpt-instruct"],
+            )
+            if not has_access:
+                pytest.fail(
+                    "Expected this to fail, submitted fallback model that key did not have access to"
+                )
+        except Exception as e:
+            if has_access:
+                pytest.fail("Expected this to work: {}".format(str(e)))
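
For reference, a minimal sketch of the request shape this patch now validates up front (the endpoint URL, API key, and model names below are illustrative placeholders mirroring tests/test_fallbacks.py). With this change, user_api_key_auth() runs can_key_call_model() for the requested model and for every entry in the client-supplied `fallbacks` list, so a key without access to a fallback model is rejected at auth time rather than when the fallback actually fires:

    # Illustrative sketch only; URL, key, and model names are placeholders.
    import requests

    resp = requests.post(
        "http://0.0.0.0:4000/chat/completions",
        headers={
            "Authorization": "Bearer sk-1234",
            "Content-Type": "application/json",
        },
        json={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "Who was Alexander?"}],
            # Client-side fallbacks: each model listed here must now also be
            # allowed for the calling key, not just the primary "model".
            "fallbacks": ["gpt-instruct"],
        },
    )
    print(resp.status_code, resp.json())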