Litellm enforce enterprise features (#7357)

* fix(proxy_server.py): enforce team id based model add only works if enterprise user

* fix(auth_checks.py): enforce common_checks can only be imported by user_api_key_auth.py

* fix(auth_checks.py): insert not premium user error message on failed common checks run
This commit is contained in:
Krish Dholakia 2024-12-21 19:14:13 -08:00 committed by GitHub
parent 152056375a
commit ae7f54498f
3 changed files with 62 additions and 1 deletions

View file

@ -9,6 +9,7 @@ Run checks for:
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import inspect
import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
@ -21,6 +22,7 @@ from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
CommonProxyErrors,
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
LiteLLM_OrganizationTable,
@ -52,6 +54,33 @@ db_cache_expiry = 5 # refresh every 5s
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
def _allowed_import_check() -> bool:
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
# Get the calling frame
caller_frame = inspect.stack()[2]
caller_function = caller_frame.function
caller_function_callable = caller_frame.frame.f_globals.get(caller_function)
allowed_function = "user_api_key_auth"
allowed_signature = inspect.signature(user_api_key_auth)
if caller_function_callable is None or not callable(caller_function_callable):
raise Exception(f"Caller function {caller_function} is not callable")
caller_signature = inspect.signature(caller_function_callable)
if caller_signature != allowed_signature:
raise TypeError(
f"The function '{caller_function}' does not match the required signature of 'user_api_key_auth'. {CommonProxyErrors.not_premium_user.value}"
)
# Check if the caller module is allowed
if caller_function != allowed_function:
raise ImportError(
f"This function can only be imported by '{allowed_function}'. {CommonProxyErrors.not_premium_user.value}"
)
return True
def common_checks( # noqa: PLR0915
request_body: dict,
team_object: Optional[LiteLLM_TeamTable],
@ -76,6 +105,7 @@ def common_checks( # noqa: PLR0915
9. Check if request body is safe
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
"""
_allowed_import_check()
_model = request_body.get("model", None)
if team_object is not None and team_object.blocked is True:
raise Exception(