Litellm enforce enterprise features (#7357)

* fix(proxy_server.py): enforce team id based model add only works if enterprise user

* fix(auth_checks.py): enforce common_checks can only be imported by user_api_key_auth.py

* fix(auth_checks.py): insert not premium user error message on failed common checks run
This commit is contained in:
Krish Dholakia 2024-12-21 19:14:13 -08:00 committed by GitHub
parent 152056375a
commit ae7f54498f
3 changed files with 62 additions and 1 deletions

View file

@ -9,6 +9,7 @@ Run checks for:
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
""" """
import inspect
import time import time
import traceback import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
@ -21,6 +22,7 @@ from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.proxy._types import ( from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES, DB_CONNECTION_ERROR_TYPES,
CommonProxyErrors,
LiteLLM_EndUserTable, LiteLLM_EndUserTable,
LiteLLM_JWTAuth, LiteLLM_JWTAuth,
LiteLLM_OrganizationTable, LiteLLM_OrganizationTable,
@ -52,6 +54,33 @@ db_cache_expiry = 5 # refresh every 5s
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
def _allowed_import_check() -> bool:
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
# Get the calling frame
caller_frame = inspect.stack()[2]
caller_function = caller_frame.function
caller_function_callable = caller_frame.frame.f_globals.get(caller_function)
allowed_function = "user_api_key_auth"
allowed_signature = inspect.signature(user_api_key_auth)
if caller_function_callable is None or not callable(caller_function_callable):
raise Exception(f"Caller function {caller_function} is not callable")
caller_signature = inspect.signature(caller_function_callable)
if caller_signature != allowed_signature:
raise TypeError(
f"The function '{caller_function}' does not match the required signature of 'user_api_key_auth'. {CommonProxyErrors.not_premium_user.value}"
)
# Check if the caller module is allowed
if caller_function != allowed_function:
raise ImportError(
f"This function can only be imported by '{allowed_function}'. {CommonProxyErrors.not_premium_user.value}"
)
return True
def common_checks( # noqa: PLR0915 def common_checks( # noqa: PLR0915
request_body: dict, request_body: dict,
team_object: Optional[LiteLLM_TeamTable], team_object: Optional[LiteLLM_TeamTable],
@ -76,6 +105,7 @@ def common_checks( # noqa: PLR0915
9. Check if request body is safe 9. Check if request body is safe
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks 10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
""" """
_allowed_import_check()
_model = request_body.get("model", None) _model = request_body.get("model", None)
if team_object is not None and team_object.blocked is True: if team_object is not None and team_object.blocked is True:
raise Exception( raise Exception(

View file

@ -6234,7 +6234,7 @@ async def add_new_model(
model_params: Deployment, model_params: Deployment,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
): ):
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj, premium_user
try: try:
import base64 import base64
@ -6246,6 +6246,12 @@ async def add_new_model(
}, },
) )
if model_params.model_info.team_id is not None and premium_user is not True:
raise HTTPException(
status_code=403,
detail={"error": CommonProxyErrors.not_premium_user.value},
)
if not check_if_team_id_matches_key( if not check_if_team_id_matches_key(
team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict
): ):

View file

@ -199,3 +199,28 @@ async def test_can_team_call_model(model, expect_to_work):
assert model_in_access_group(**args) assert model_in_access_group(**args)
else: else:
assert not model_in_access_group(**args) assert not model_in_access_group(**args)
def test_common_checks_import():
"""
Enforce that common_checks can only be imported by the 'user_api_key_auth()' function.
"""
try:
from litellm.proxy.auth.user_api_key_auth import common_checks
from litellm.proxy._types import CommonProxyErrors
common_checks(
request_body={},
team_object=None,
user_object=None,
end_user_object=None,
global_proxy_spend=None,
general_settings={},
route="",
llm_router=None,
)
pytest.fail(
"common_checks can only be imported by the 'user_api_key_auth()' function."
)
except Exception as e:
assert CommonProxyErrors.not_premium_user.value in str(e)