mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Litellm enforce enterprise features (#7357)
* fix(proxy_server.py): enforce team id based model add only works if enterprise user * fix(auth_checks.py): enforce common_checks can only be imported by user_api_key_auth.py * fix(auth_checks.py): insert not premium user error message on failed common checks run
This commit is contained in:
parent
49b6e539b7
commit
a8ae2f551a
3 changed files with 62 additions and 1 deletions
|
@ -9,6 +9,7 @@ Run checks for:
|
||||||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||||
|
@ -21,6 +22,7 @@ from litellm.caching.caching import DualCache
|
||||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
DB_CONNECTION_ERROR_TYPES,
|
DB_CONNECTION_ERROR_TYPES,
|
||||||
|
CommonProxyErrors,
|
||||||
LiteLLM_EndUserTable,
|
LiteLLM_EndUserTable,
|
||||||
LiteLLM_JWTAuth,
|
LiteLLM_JWTAuth,
|
||||||
LiteLLM_OrganizationTable,
|
LiteLLM_OrganizationTable,
|
||||||
|
@ -52,6 +54,33 @@ db_cache_expiry = 5 # refresh every 5s
|
||||||
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
|
||||||
|
|
||||||
|
|
||||||
|
def _allowed_import_check() -> bool:
|
||||||
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
|
||||||
|
# Get the calling frame
|
||||||
|
caller_frame = inspect.stack()[2]
|
||||||
|
caller_function = caller_frame.function
|
||||||
|
caller_function_callable = caller_frame.frame.f_globals.get(caller_function)
|
||||||
|
|
||||||
|
allowed_function = "user_api_key_auth"
|
||||||
|
allowed_signature = inspect.signature(user_api_key_auth)
|
||||||
|
if caller_function_callable is None or not callable(caller_function_callable):
|
||||||
|
raise Exception(f"Caller function {caller_function} is not callable")
|
||||||
|
caller_signature = inspect.signature(caller_function_callable)
|
||||||
|
|
||||||
|
if caller_signature != allowed_signature:
|
||||||
|
raise TypeError(
|
||||||
|
f"The function '{caller_function}' does not match the required signature of 'user_api_key_auth'. {CommonProxyErrors.not_premium_user.value}"
|
||||||
|
)
|
||||||
|
# Check if the caller module is allowed
|
||||||
|
if caller_function != allowed_function:
|
||||||
|
raise ImportError(
|
||||||
|
f"This function can only be imported by '{allowed_function}'. {CommonProxyErrors.not_premium_user.value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def common_checks( # noqa: PLR0915
|
def common_checks( # noqa: PLR0915
|
||||||
request_body: dict,
|
request_body: dict,
|
||||||
team_object: Optional[LiteLLM_TeamTable],
|
team_object: Optional[LiteLLM_TeamTable],
|
||||||
|
@ -76,6 +105,7 @@ def common_checks( # noqa: PLR0915
|
||||||
9. Check if request body is safe
|
9. Check if request body is safe
|
||||||
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
|
10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
|
||||||
"""
|
"""
|
||||||
|
_allowed_import_check()
|
||||||
_model = request_body.get("model", None)
|
_model = request_body.get("model", None)
|
||||||
if team_object is not None and team_object.blocked is True:
|
if team_object is not None and team_object.blocked is True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
|
@ -6234,7 +6234,7 @@ async def add_new_model(
|
||||||
model_params: Deployment,
|
model_params: Deployment,
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj
|
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config, prisma_client, master_key, store_model_in_db, proxy_logging_obj, premium_user
|
||||||
try:
|
try:
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
@ -6246,6 +6246,12 @@ async def add_new_model(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_params.model_info.team_id is not None and premium_user is not True:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail={"error": CommonProxyErrors.not_premium_user.value},
|
||||||
|
)
|
||||||
|
|
||||||
if not check_if_team_id_matches_key(
|
if not check_if_team_id_matches_key(
|
||||||
team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict
|
team_id=model_params.model_info.team_id, user_api_key_dict=user_api_key_dict
|
||||||
):
|
):
|
||||||
|
|
|
@ -199,3 +199,28 @@ async def test_can_team_call_model(model, expect_to_work):
|
||||||
assert model_in_access_group(**args)
|
assert model_in_access_group(**args)
|
||||||
else:
|
else:
|
||||||
assert not model_in_access_group(**args)
|
assert not model_in_access_group(**args)
|
||||||
|
|
||||||
|
|
||||||
|
def test_common_checks_import():
|
||||||
|
"""
|
||||||
|
Enforce that common_checks can only be imported by the 'user_api_key_auth()' function.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from litellm.proxy.auth.user_api_key_auth import common_checks
|
||||||
|
from litellm.proxy._types import CommonProxyErrors
|
||||||
|
|
||||||
|
common_checks(
|
||||||
|
request_body={},
|
||||||
|
team_object=None,
|
||||||
|
user_object=None,
|
||||||
|
end_user_object=None,
|
||||||
|
global_proxy_spend=None,
|
||||||
|
general_settings={},
|
||||||
|
route="",
|
||||||
|
llm_router=None,
|
||||||
|
)
|
||||||
|
pytest.fail(
|
||||||
|
"common_checks can only be imported by the 'user_api_key_auth()' function."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
assert CommonProxyErrors.not_premium_user.value in str(e)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue