forked from phoenix/litellm-mirror
Merge pull request #2726 from BerriAI/litellm_enforce_user_param
feat(auth_checks.py): enable admin to enforce 'user' param for all openai endpoints
This commit is contained in:
commit
ef51544741
5 changed files with 42 additions and 1 deletions
|
@ -49,6 +49,7 @@ jobs:
|
||||||
pip install argon2-cffi
|
pip install argon2-cffi
|
||||||
pip install "pytest-mock==3.12.0"
|
pip install "pytest-mock==3.12.0"
|
||||||
pip install python-multipart
|
pip install python-multipart
|
||||||
|
pip install google-cloud-aiplatform
|
||||||
- save_cache:
|
- save_cache:
|
||||||
paths:
|
paths:
|
||||||
- ./venv
|
- ./venv
|
||||||
|
|
|
@ -698,6 +698,8 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||||
team_tpm_limit: Optional[int] = None
|
team_tpm_limit: Optional[int] = None
|
||||||
team_rpm_limit: Optional[int] = None
|
team_rpm_limit: Optional[int] = None
|
||||||
team_max_budget: Optional[float] = None
|
team_max_budget: Optional[float] = None
|
||||||
|
team_models: List = []
|
||||||
|
team_blocked: bool = False
|
||||||
soft_budget: Optional[float] = None
|
soft_budget: Optional[float] = None
|
||||||
team_model_aliases: Optional[Dict] = None
|
team_model_aliases: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ from litellm.proxy._types import (
|
||||||
LiteLLM_TeamTable,
|
LiteLLM_TeamTable,
|
||||||
LiteLLMRoutes,
|
LiteLLMRoutes,
|
||||||
)
|
)
|
||||||
from typing import Optional, Literal
|
from typing import Optional, Literal, Union
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
@ -26,6 +26,8 @@ def common_checks(
|
||||||
request_body: dict,
|
request_body: dict,
|
||||||
team_object: LiteLLM_TeamTable,
|
team_object: LiteLLM_TeamTable,
|
||||||
end_user_object: Optional[LiteLLM_EndUserTable],
|
end_user_object: Optional[LiteLLM_EndUserTable],
|
||||||
|
general_settings: dict,
|
||||||
|
route: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Common checks across jwt + key-based auth.
|
Common checks across jwt + key-based auth.
|
||||||
|
@ -34,6 +36,7 @@ def common_checks(
|
||||||
2. If team can call model
|
2. If team can call model
|
||||||
3. If team is in budget
|
3. If team is in budget
|
||||||
4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
4. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
|
5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
|
||||||
"""
|
"""
|
||||||
_model = request_body.get("model", None)
|
_model = request_body.get("model", None)
|
||||||
if team_object.blocked == True:
|
if team_object.blocked == True:
|
||||||
|
@ -65,6 +68,16 @@ def common_checks(
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
|
f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}"
|
||||||
)
|
)
|
||||||
|
# 5. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
|
||||||
|
if (
|
||||||
|
general_settings.get("enforce_user_param", None) is not None
|
||||||
|
and general_settings["enforce_user_param"] == True
|
||||||
|
):
|
||||||
|
if route in LiteLLMRoutes.openai_routes.value and "user" not in request_body:
|
||||||
|
raise Exception(
|
||||||
|
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -439,6 +439,8 @@ async def user_api_key_auth(
|
||||||
request_body=request_data,
|
request_body=request_data,
|
||||||
team_object=team_object,
|
team_object=team_object,
|
||||||
end_user_object=end_user_object,
|
end_user_object=end_user_object,
|
||||||
|
general_settings=general_settings,
|
||||||
|
route=route,
|
||||||
)
|
)
|
||||||
# save user object in cache
|
# save user object in cache
|
||||||
await user_api_key_cache.async_set_cache(
|
await user_api_key_cache.async_set_cache(
|
||||||
|
@ -866,6 +868,23 @@ async def user_api_key_auth(
|
||||||
f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check 8: Additional Common Checks across jwt + key auth
|
||||||
|
_team_obj = LiteLLM_TeamTable(
|
||||||
|
team_id=valid_token.team_id,
|
||||||
|
max_budget=valid_token.team_max_budget,
|
||||||
|
spend=valid_token.team_spend,
|
||||||
|
tpm_limit=valid_token.team_tpm_limit,
|
||||||
|
rpm_limit=valid_token.team_rpm_limit,
|
||||||
|
blocked=valid_token.team_blocked,
|
||||||
|
models=valid_token.team_models,
|
||||||
|
)
|
||||||
|
_ = common_checks(
|
||||||
|
request_body=request_data,
|
||||||
|
team_object=_team_obj,
|
||||||
|
end_user_object=None,
|
||||||
|
general_settings=general_settings,
|
||||||
|
route=route,
|
||||||
|
)
|
||||||
# Token passed all checks
|
# Token passed all checks
|
||||||
api_key = valid_token.token
|
api_key = valid_token.token
|
||||||
|
|
||||||
|
|
|
@ -1013,6 +1013,8 @@ class PrismaClient:
|
||||||
t.max_budget AS team_max_budget,
|
t.max_budget AS team_max_budget,
|
||||||
t.tpm_limit AS team_tpm_limit,
|
t.tpm_limit AS team_tpm_limit,
|
||||||
t.rpm_limit AS team_rpm_limit,
|
t.rpm_limit AS team_rpm_limit,
|
||||||
|
t.models AS team_models,
|
||||||
|
t.blocked AS team_blocked,
|
||||||
m.aliases as team_model_aliases
|
m.aliases as team_model_aliases
|
||||||
FROM "LiteLLM_VerificationToken" AS v
|
FROM "LiteLLM_VerificationToken" AS v
|
||||||
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
|
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
|
||||||
|
@ -1023,6 +1025,10 @@ class PrismaClient:
|
||||||
response = await self.db.query_first(query=sql_query)
|
response = await self.db.query_first(query=sql_query)
|
||||||
|
|
||||||
if response is not None:
|
if response is not None:
|
||||||
|
if response["team_models"] is None:
|
||||||
|
response["team_models"] = []
|
||||||
|
if response["team_blocked"] is None:
|
||||||
|
response["team_blocked"] = False
|
||||||
response = LiteLLM_VerificationTokenView(**response)
|
response = LiteLLM_VerificationTokenView(**response)
|
||||||
# for prisma we need to cast the expires time to str
|
# for prisma we need to cast the expires time to str
|
||||||
if response.expires is not None and isinstance(
|
if response.expires is not None and isinstance(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue