mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Litellm enforce enterprise features (#7357)
* fix(proxy_server.py): enforce team id based model add only works if enterprise user * fix(auth_checks.py): enforce common_checks can only be imported by user_api_key_auth.py * fix(auth_checks.py): insert not premium user error message on failed common checks run
This commit is contained in:
parent
152056375a
commit
ae7f54498f
3 changed files with 62 additions and 1 deletions
|
@ -9,6 +9,7 @@ Run checks for:
|
|||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import time
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||
|
@ -21,6 +22,7 @@ from litellm.caching.caching import DualCache
|
|||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||
from litellm.proxy._types import (
|
||||
DB_CONNECTION_ERROR_TYPES,
|
||||
CommonProxyErrors,
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_JWTAuth,
|
||||
LiteLLM_OrganizationTable,
|
||||
|
@ -52,6 +54,33 @@ db_cache_expiry = 5 # refresh every 5s
|
|||
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||
|
||||
|
||||
def _allowed_import_check() -> bool:
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
# Get the calling frame
|
||||
caller_frame = inspect.stack()[2]
|
||||
caller_function = caller_frame.function
|
||||
caller_function_callable = caller_frame.frame.f_globals.get(caller_function)
|
||||
|
||||
allowed_function = "user_api_key_auth"
|
||||
allowed_signature = inspect.signature(user_api_key_auth)
|
||||
if caller_function_callable is None or not callable(caller_function_callable):
|
||||
raise Exception(f"Caller function {caller_function} is not callable")
|
||||
caller_signature = inspect.signature(caller_function_callable)
|
||||
|
||||
if caller_signature != allowed_signature:
|
||||
raise TypeError(
|
||||
f"The function '{caller_function}' does not match the required signature of 'user_api_key_auth'. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
# Check if the caller module is allowed
|
||||
if caller_function != allowed_function:
|
||||
raise ImportError(
|
||||
f"This function can only be imported by '{allowed_function}'. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def common_checks( # noqa: PLR0915
|
||||
request_body: dict,
|
||||
team_object: Optional[LiteLLM_TeamTable],
|
||||
|
@ -76,6 +105,7 @@ def common_checks( # noqa: PLR0915
|
|||
9. Check if request body is safe
|
||||
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
|
||||
"""
|
||||
_allowed_import_check()
|
||||
_model = request_body.get("model", None)
|
||||
if team_object is not None and team_object.blocked is True:
|
||||
raise Exception(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue