Merge pull request #2726 from BerriAI/litellm_enforce_user_param

feat(auth_checks.py): enable admin to enforce 'user' param for all openai endpoints
Krish Dholakia 2024-03-27 19:38:52 -07:00 committed by GitHub
commit ef51544741
5 changed files with 42 additions and 1 deletion

View file

@@ -49,6 +49,7 @@ jobs:
 pip install argon2-cffi
 pip install "pytest-mock==3.12.0"
 pip install python-multipart
+pip install google-cloud-aiplatform
 - save_cache:
     paths:
       - ./venv
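The new google-cloud-aiplatform dependency is presumably needed by Vertex AI-related tests in this CI job. A minimal sanity check that the wheel resolves (a sketch; it assumes only the SDK's standard top-level module):

# The google-cloud-aiplatform wheel ships the google.cloud.aiplatform module.
import google.cloud.aiplatform as aiplatform

print(aiplatform.__version__)  # prints the installed SDK version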

View file

@@ -698,6 +698,8 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
     team_tpm_limit: Optional[int] = None
     team_rpm_limit: Optional[int] = None
     team_max_budget: Optional[float] = None
+    team_models: List = []
+    team_blocked: bool = False
     soft_budget: Optional[float] = None
     team_model_aliases: Optional[Dict] = None
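A side note on the two new fields: the mutable default `team_models: List = []` is safe here assuming `LiteLLM_VerificationTokenView` is a Pydantic model (the annotated-default field style and the `LiteLLM_VerificationTokenView(**response)` construction in the PrismaClient hunk below suggest it is), because Pydantic copies mutable defaults per instance rather than sharing one list. A minimal sketch:

from typing import List
from pydantic import BaseModel

class View(BaseModel):  # stand-in for LiteLLM_VerificationTokenView
    team_models: List = []
    team_blocked: bool = False

a, b = View(), View()
a.team_models.append("gpt-4")
# Each instance got its own copy of the default; b is unaffected.
assert a.team_models == ["gpt-4"] and b.team_models == []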

View file

@@ -15,7 +15,7 @@ from litellm.proxy._types import (
     LiteLLM_TeamTable,
     LiteLLMRoutes,
 )
-from typing import Optional, Literal
+from typing import Optional, Literal, Union
 from litellm.proxy.utils import PrismaClient
 from litellm.caching import DualCache
@@ -26,6 +26,8 @@ def common_checks(
     request_body: dict,
     team_object: LiteLLM_TeamTable,
     end_user_object: Optional[LiteLLM_EndUserTable],
+    general_settings: dict,
+    route: str,
 ) -> bool:
     """
     Common checks across jwt + key-based auth.
@@ -34,6 +36,7 @@ def common_checks(
     2. If team can call model
     3. If team is in budget
     4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
+    5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
     """
     _model = request_body.get("model", None)
     if team_object.blocked == True:
@@ -65,6 +68,16 @@ def common_checks(
         raise Exception(
             f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
         )
+    # 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
+    if (
+        general_settings.get("enforce_user_param", None) is not None
+        and general_settings["enforce_user_param"] == True
+    ):
+        if route in LiteLLMRoutes.openai_routes.value and "user" not in request_body:
+            raise Exception(
+                f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
+            )
     return True
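With enforcement switched on (e.g. `enforce_user_param: true` under `general_settings` in the proxy config), any call to an OpenAI route must carry a `user` field in the request body. A client-side sketch of both outcomes; the proxy URL and virtual key below are placeholders:

import requests

PROXY = "http://localhost:4000"  # placeholder proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}  # placeholder virtual key

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hello"}],
}

# No 'user' param: common_checks raises, so the proxy rejects the call.
r = requests.post(f"{PROXY}/chat/completions", headers=HEADERS, json=payload)
print(r.status_code)  # error status when enforce_user_param is enabled

# With 'user' param: the check passes and the request is forwarded upstream.
payload["user"] = "customer-123"
r = requests.post(f"{PROXY}/chat/completions", headers=HEADERS, json=payload)
print(r.status_code)  # 200 on success

Note the guard itself reduces to `general_settings.get("enforce_user_param") == True`; the separate `is not None` test is redundant but harmless.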

View file

@@ -439,6 +439,8 @@ async def user_api_key_auth(
                 request_body=request_data,
                 team_object=team_object,
                 end_user_object=end_user_object,
+                general_settings=general_settings,
+                route=route,
             )
             # save user object in cache
             await user_api_key_cache.async_set_cache(
@@ -866,6 +868,23 @@ async def user_api_key_auth(
                     f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
                 )
+            # Check 8: Additional Common Checks across jwt + key auth
+            _team_obj = LiteLLM_TeamTable(
+                team_id=valid_token.team_id,
+                max_budget=valid_token.team_max_budget,
+                spend=valid_token.team_spend,
+                tpm_limit=valid_token.team_tpm_limit,
+                rpm_limit=valid_token.team_rpm_limit,
+                blocked=valid_token.team_blocked,
+                models=valid_token.team_models,
+            )
+            _ = common_checks(
+                request_body=request_data,
+                team_object=_team_obj,
+                end_user_object=None,
+                general_settings=general_settings,
+                route=route,
+            )
             # Token passed all checks
             api_key = valid_token.token
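For key-based auth the team data arrives denormalized on the token row (the `team_`-prefixed columns from the SQL query below), so it is re-packed into a `LiteLLM_TeamTable` to let the same `common_checks` path serve both JWT and key auth. A sketch of the re-packing pattern with illustrative stand-in models, not LiteLLM's actual definitions:

from typing import List, Optional
from pydantic import BaseModel

class Team(BaseModel):  # stand-in for LiteLLM_TeamTable
    team_id: Optional[str] = None
    max_budget: Optional[float] = None
    spend: float = 0.0
    blocked: bool = False
    models: List = []

class TokenView(BaseModel):  # stand-in for the flattened token view
    team_id: Optional[str] = None
    team_max_budget: Optional[float] = None
    team_spend: float = 0.0
    team_blocked: bool = False
    team_models: List = []

def repack(v: TokenView) -> Team:
    # Drop the "team_" prefix added by the SQL aliases, restoring the
    # object shape that common_checks() expects.
    return Team(
        team_id=v.team_id,
        max_budget=v.team_max_budget,
        spend=v.team_spend,
        blocked=v.team_blocked,
        models=v.team_models,
    )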

View file

@@ -1013,6 +1013,8 @@ class PrismaClient:
                     t.max_budget AS team_max_budget,
                     t.tpm_limit AS team_tpm_limit,
                     t.rpm_limit AS team_rpm_limit,
+                    t.models AS team_models,
+                    t.blocked AS team_blocked,
                     m.aliases as team_model_aliases
                 FROM "LiteLLM_VerificationToken" AS v
                 LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
@@ -1023,6 +1025,10 @@ class PrismaClient:
             response = await self.db.query_first(query=sql_query)
             if response is not None:
+                if response["team_models"] is None:
+                    response["team_models"] = []
+                if response["team_blocked"] is None:
+                    response["team_blocked"] = False
                 response = LiteLLM_VerificationTokenView(**response)
                 # for prisma we need to cast the expires time to str
                 if response.expires is not None and isinstance(
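The coercion above matters because of the `LEFT JOIN`: a token with no team (or a team row with NULL `models`/`blocked`) surfaces `None` for the new columns, and `None` fails validation against the non-Optional `team_models: List` and `team_blocked: bool` fields. A sketch of the failure mode, again assuming a Pydantic model:

from typing import List
from pydantic import BaseModel, ValidationError

class View(BaseModel):  # stand-in for LiteLLM_VerificationTokenView
    team_models: List = []
    team_blocked: bool = False

row = {"team_models": None, "team_blocked": None}  # LEFT JOIN found no team

try:
    View(**row)
except ValidationError as err:
    print(err)  # None is not a valid list / bool

# The PR's pre-validation coercion restores the declared defaults:
if row["team_models"] is None:
    row["team_models"] = []
if row["team_blocked"] is None:
    row["team_blocked"] = False
print(View(**row))  # team_models=[] team_blocked=False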