diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index c01081abeb..2cb66d4cc0 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -9,6 +9,7 @@ Run checks for: 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget """ +import inspect import time import traceback from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional @@ -21,6 +22,7 @@ from litellm.caching.caching import DualCache from litellm.caching.dual_cache import LimitedSizeOrderedDict from litellm.proxy._types import ( DB_CONNECTION_ERROR_TYPES, + CommonProxyErrors, LiteLLM_EndUserTable, LiteLLM_JWTAuth, LiteLLM_OrganizationTable, @@ -52,6 +54,33 @@ db_cache_expiry = 5 # refresh every 5s all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value +def _allowed_import_check() -> bool: + from litellm.proxy.auth.user_api_key_auth import user_api_key_auth + + # Get the calling frame + caller_frame = inspect.stack()[2] + caller_function = caller_frame.function + caller_function_callable = caller_frame.frame.f_globals.get(caller_function) + + allowed_function = "user_api_key_auth" + allowed_signature = inspect.signature(user_api_key_auth) + if caller_function_callable is None or not callable(caller_function_callable): + raise Exception(f"Caller function {caller_function} is not callable") + caller_signature = inspect.signature(caller_function_callable) + + if caller_signature != allowed_signature: + raise TypeError( + f"The function '{caller_function}' does not match the required signature of 'user_api_key_auth'. {CommonProxyErrors.not_premium_user.value}" + ) + # Check if the caller module is allowed + if caller_function != allowed_function: + raise ImportError( + f"This function can only be imported by '{allowed_function}'. {CommonProxyErrors.not_premium_user.value}" + ) + + return True + + def common_checks( # noqa: PLR0915 request_body: dict, team_object: Optional[LiteLLM_TeamTable], @@ -76,6 +105,7 @@ def common_checks( # noqa: PLR0915 9. Check if request body is safe 10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks """ + _allowed_import_check() _model = request_body.get("model", None) if team_object is not None and team_object.blocked is True: raise Exception( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c5c5cfe1eb..fcf720d838 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -6234,7 +6234,7 @@ async def add_new_model( model_params: Deployment, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): - global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj + global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj, premium_user try: import base64 @@ -6246,6 +6246,12 @@ async def add_new_model( }, ) + if model_params.model_info.team_id is not None and premium_user is not True: + raise HTTPException( + status_code=403, + detail={"error": CommonProxyErrors.not_premium_user.value}, + ) + if not check_if_team_id_matches_key( team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict ): diff --git a/tests/local_testing/test_auth_checks.py b/tests/local_testing/test_auth_checks.py index 67b5cf11df..741ffbdac0 100644 --- a/tests/local_testing/test_auth_checks.py +++ b/tests/local_testing/test_auth_checks.py @@ -199,3 +199,28 @@ async def test_can_team_call_model(model, expect_to_work): assert model_in_access_group(**args) else: assert not model_in_access_group(**args) + + +def test_common_checks_import(): + """ + Enforce that common_checks can only be imported by the 'user_api_key_auth()' function. + """ + try: + from litellm.proxy.auth.user_api_key_auth import common_checks + from litellm.proxy._types import CommonProxyErrors + + common_checks( + request_body={}, + team_object=None, + user_object=None, + end_user_object=None, + global_proxy_spend=None, + general_settings={}, + route="", + llm_router=None, + ) + pytest.fail( + "common_checks can only be imported by the 'user_api_key_auth()' function." + ) + except Exception as e: + assert CommonProxyErrors.not_premium_user.value in str(e)