forked from phoenix/litellm-mirror
Merge pull request #2726 from BerriAI/litellm_enforce_user_param
feat(auth_checks.py): enable admin to enforce 'user' param for all openai endpoints
Commit ef51544741
5 changed files with 42 additions and 1 deletion
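As a rough usage sketch (not part of the diff): assuming the admin enables the new setting via 'enforce_user_param: true' under 'general_settings' in the proxy config, and assuming a proxy at http://localhost:4000 with a placeholder key and model name, a client calling an OpenAI-compatible route now has to include the 'user' field:

# Sketch only: proxy URL, key, and model name are placeholders, not values from this PR.
import openai

client = openai.OpenAI(
    api_key="sk-1234",                 # placeholder LiteLLM proxy key
    base_url="http://localhost:4000",  # assumed proxy address
)

# With enforce_user_param enabled and no 'user' field, the proxy would reject
# the request with the "'user' param not passed in" error added in this PR.
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    user="end-user-123",  # the enforced 'user' param
)
print(response.choices[0].message.content)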
@@ -49,6 +49,7 @@ jobs:
          pip install argon2-cffi
          pip install "pytest-mock==3.12.0"
          pip install python-multipart
          pip install google-cloud-aiplatform
      - save_cache:
          paths:
            - ./venv
@@ -698,6 +698,8 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
    team_tpm_limit: Optional[int] = None
    team_rpm_limit: Optional[int] = None
    team_max_budget: Optional[float] = None
    team_models: List = []
    team_blocked: bool = False
    soft_budget: Optional[float] = None
    team_model_aliases: Optional[Dict] = None
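The team fields on this view carry pydantic defaults, so a key with no team row still validates. A minimal standalone sketch mirroring only the fields shown above (the real LiteLLM_VerificationTokenView also inherits the base token fields from LiteLLM_VerificationToken):

# Standalone sketch, not the real class.
from typing import Dict, List, Optional
from pydantic import BaseModel

class TeamFieldsSketch(BaseModel):
    team_tpm_limit: Optional[int] = None
    team_rpm_limit: Optional[int] = None
    team_max_budget: Optional[float] = None
    team_models: List = []
    team_blocked: bool = False
    soft_budget: Optional[float] = None
    team_model_aliases: Optional[Dict] = None

# A key that is not attached to any team just keeps the defaults.
print(TeamFieldsSketch())
print(TeamFieldsSketch(team_blocked=True, team_model_aliases={"gpt-4o": "gpt-4o-mini"}))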
@@ -15,7 +15,7 @@ from litellm.proxy._types import (
    LiteLLM_TeamTable,
    LiteLLMRoutes,
)
from typing import Optional, Literal
from typing import Optional, Literal, Union
from litellm.proxy.utils import PrismaClient
from litellm.caching import DualCache
@@ -26,6 +26,8 @@ def common_checks(
    request_body: dict,
    team_object: LiteLLM_TeamTable,
    end_user_object: Optional[LiteLLM_EndUserTable],
    general_settings: dict,
    route: str,
) -> bool:
    """
    Common checks across jwt + key-based auth.
@@ -34,6 +36,7 @@ def common_checks(
    2. If team can call model
    3. If team is in budget
    4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
    5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
    """
    _model = request_body.get("model", None)
    if team_object.blocked == True:
@@ -65,6 +68,16 @@ def common_checks(
            raise Exception(
                f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
            )
    # 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
    if (
        general_settings.get("enforce_user_param", None) is not None
        and general_settings["enforce_user_param"] == True
    ):
        if route in LiteLLMRoutes.openai_routes.value and "user" not in request_body:
            raise Exception(
                f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
            )

    return True
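The new check only fires when 'enforce_user_param' is truthy and the route is one of the OpenAI-compatible routes. A standalone sketch of that logic, with OPENAI_ROUTES standing in for LiteLLMRoutes.openai_routes.value (the listed routes are an assumption for illustration):

OPENAI_ROUTES = {"/chat/completions", "/completions", "/embeddings"}  # assumed subset

def enforce_user_param(request_body: dict, general_settings: dict, route: str) -> None:
    """Raise if 'enforce_user_param' is on and the request body has no 'user' field."""
    if general_settings.get("enforce_user_param", None) is True:
        if route in OPENAI_ROUTES and "user" not in request_body:
            raise Exception(
                f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
            )

# Passes: 'user' is supplied.
enforce_user_param(
    {"model": "gpt-3.5-turbo", "user": "end-user-123"},
    {"enforce_user_param": True},
    "/chat/completions",
)
# Would raise: same request body without 'user'.
# enforce_user_param({"model": "gpt-3.5-turbo"}, {"enforce_user_param": True}, "/chat/completions")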
@@ -439,6 +439,8 @@ async def user_api_key_auth(
                request_body=request_data,
                team_object=team_object,
                end_user_object=end_user_object,
                general_settings=general_settings,
                route=route,
            )
            # save user object in cache
            await user_api_key_cache.async_set_cache(
@@ -866,6 +868,23 @@ async def user_api_key_auth(
                    f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
                )

            # Check 8: Additional Common Checks across jwt + key auth
            _team_obj = LiteLLM_TeamTable(
                team_id=valid_token.team_id,
                max_budget=valid_token.team_max_budget,
                spend=valid_token.team_spend,
                tpm_limit=valid_token.team_tpm_limit,
                rpm_limit=valid_token.team_rpm_limit,
                blocked=valid_token.team_blocked,
                models=valid_token.team_models,
            )
            _ = common_checks(
                request_body=request_data,
                team_object=_team_obj,
                end_user_object=None,
                general_settings=general_settings,
                route=route,
            )
            # Token passed all checks
            api_key = valid_token.token
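Check 8 repackages the cached key's team_* columns into a team object so key-based auth reuses the same common_checks() path as JWT auth. A sketch of just that field mapping, using stand-in dataclasses rather than the real pydantic models:

# Illustrative stand-ins only, to show the token -> team mapping.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class TokenView:  # hypothetical stand-in for the cached valid_token
    team_id: Optional[str] = None
    team_spend: Optional[float] = None
    team_max_budget: Optional[float] = None
    team_tpm_limit: Optional[int] = None
    team_rpm_limit: Optional[int] = None
    team_blocked: bool = False
    team_models: List[str] = field(default_factory=list)

@dataclass
class TeamTable:  # hypothetical stand-in for LiteLLM_TeamTable
    team_id: Optional[str]
    max_budget: Optional[float]
    spend: Optional[float]
    tpm_limit: Optional[int]
    rpm_limit: Optional[int]
    blocked: bool
    models: List[str]

def team_from_token(tok: TokenView) -> TeamTable:
    # Same mapping as the _team_obj construction in the hunk above.
    return TeamTable(
        team_id=tok.team_id,
        max_budget=tok.team_max_budget,
        spend=tok.team_spend,
        tpm_limit=tok.team_tpm_limit,
        rpm_limit=tok.team_rpm_limit,
        blocked=tok.team_blocked,
        models=tok.team_models,
    )

print(team_from_token(TokenView(team_id="team-1", team_spend=0.0, team_max_budget=10.0)))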
@@ -1013,6 +1013,8 @@ class PrismaClient:
                t.max_budget AS team_max_budget,
                t.tpm_limit AS team_tpm_limit,
                t.rpm_limit AS team_rpm_limit,
                t.models AS team_models,
                t.blocked AS team_blocked,
                m.aliases as team_model_aliases
            FROM "LiteLLM_VerificationToken" AS v
            LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
@@ -1023,6 +1025,10 @@ class PrismaClient:
            response = await self.db.query_first(query=sql_query)

            if response is not None:
                if response["team_models"] is None:
                    response["team_models"] = []
                if response["team_blocked"] is None:
                    response["team_blocked"] = False
                response = LiteLLM_VerificationTokenView(**response)
                # for prisma we need to cast the expires time to str
                if response.expires is not None and isinstance(
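When the LEFT JOIN finds no team (or model-alias) row, the joined columns come back as NULL, so the added lines coerce them to usable defaults before the row is loaded into LiteLLM_VerificationTokenView. A minimal sketch of that coercion step:

# Minimal sketch of the new NULL handling for the joined team columns.
from typing import Any, Dict, Optional

def apply_team_defaults(row: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    if row is None:
        return None
    if row.get("team_models") is None:
        row["team_models"] = []      # no team join -> empty allowed-model list
    if row.get("team_blocked") is None:
        row["team_blocked"] = False  # no team join -> team is not blocked
    return row

print(apply_team_defaults({"token": "hashed-key", "team_models": None, "team_blocked": None}))
# -> {'token': 'hashed-key', 'team_models': [], 'team_blocked': False}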