fix(user_api_key_auth.py): respect team budgets over user budget, if key belongs to team

Closes https://github.com/BerriAI/litellm/issues/5097
This commit is contained in:
Krrish Dholakia 2024-08-07 14:32:27 -07:00
parent f579aef740
commit d832327ccf
3 changed files with 92 additions and 110 deletions

View file

@ -55,11 +55,11 @@ def common_checks(
1. If team is blocked 1. If team is blocked
2. If team can call model 2. If team can call model
3. If team is in budget 3. If team is in budget
5. If user passed in (JWT or key.user_id) - is in budget 4. If user passed in (JWT or key.user_id) - is in budget
4. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget 5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints 6. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
7. [OPTIONAL] If guardrails modified - is request allowed to change this 8. [OPTIONAL] If guardrails modified - is request allowed to change this
""" """
_model = request_body.get("model", None) _model = request_body.get("model", None)
if team_object is not None and team_object.blocked is True: if team_object is not None and team_object.blocked is True:
@ -91,12 +91,19 @@ def common_checks(
raise Exception( raise Exception(
f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}" f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
) )
if user_object is not None and user_object.max_budget is not None: # 4. If user is in budget
## 4.1 check personal budget, if personal key
if (
(team_object is None or team_object.team_id is None)
and user_object is not None
and user_object.max_budget is not None
):
user_budget = user_object.max_budget user_budget = user_object.max_budget
if user_budget > user_object.spend: if user_budget < user_object.spend:
raise Exception( raise Exception(
f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}" f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}"
) )
## 4.2 check team member budget, if team key
# 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
if end_user_object is not None and end_user_object.litellm_budget_table is not None: if end_user_object is not None and end_user_object.litellm_budget_table is not None:
end_user_budget = end_user_object.litellm_budget_table.max_budget end_user_budget = end_user_object.litellm_budget_table.max_budget

View file

@ -552,6 +552,7 @@ async def user_api_key_auth(
key=api_key key=api_key
) )
if valid_token is None: if valid_token is None:
user_obj: Optional[LiteLLM_UserTable] = None
## check db ## check db
verbose_proxy_logger.debug("api key: %s", api_key) verbose_proxy_logger.debug("api key: %s", api_key)
if prisma_client is not None: if prisma_client is not None:
@ -650,114 +651,26 @@ async def user_api_key_auth(
valid_token=valid_token, valid_token=valid_token,
) )
# Check 2. If user_id for this token is in budget # Check 2. If user_id for this token is in budget - done in common_checks()
if valid_token.user_id is not None: if valid_token.user_id is not None:
user_id_list = [valid_token.user_id] user_obj = await get_user_object(
for id in user_id_list: user_id=valid_token.user_id,
value = user_api_key_cache.get_cache(key=id) prisma_client=prisma_client,
if value is not None: user_api_key_cache=user_api_key_cache,
user_id_upsert=False,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
# value = user_api_key_cache.get_cache(key=id)
if user_obj is not None:
if user_id_information is None: if user_id_information is None:
user_id_information = [] user_id_information = []
user_id_information.append(value) user_id_information.append(user_obj.model_dump())
if user_id_information is None or (
isinstance(user_id_information, list)
and len(user_id_information) < 1
):
if prisma_client is not None:
user_id_information = await prisma_client.get_data(
user_id_list=[
valid_token.user_id,
],
table_name="user",
query_type="find_all",
)
if user_id_information is not None:
for _id in user_id_information:
await user_api_key_cache.async_set_cache(
key=_id["user_id"],
value=_id,
)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"user_id_information: {user_id_information}" f"user_id_information: {user_id_information}"
) )
if user_id_information is not None:
if isinstance(user_id_information, list):
## Check if user in budget
for _user in user_id_information:
if _user is None:
continue
assert isinstance(_user, dict)
# check if user is admin #
# Token exists, not expired now check if its in budget for the user
user_max_budget = _user.get("max_budget", None)
user_current_spend = _user.get("spend", None)
verbose_proxy_logger.debug(
f"user_id: {_user.get('user_id', None)}; user_max_budget: {user_max_budget}; user_current_spend: {user_current_spend}"
)
if (
user_max_budget is not None
and user_current_spend is not None
):
call_info = CallInfo(
token=valid_token.token,
spend=user_current_spend,
max_budget=user_max_budget,
user_id=_user.get("user_id", None),
user_email=_user.get("user_email", None),
key_alias=valid_token.key_alias,
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="user_budget",
user_info=call_info,
)
)
_user_id = _user.get("user_id", None)
if user_current_spend > user_max_budget:
raise litellm.BudgetExceededError(
current_cost=user_current_spend,
max_budget=user_max_budget,
)
else:
# Token exists, not expired now check if its in budget for the user
user_max_budget = getattr(
user_id_information, "max_budget", None
)
user_current_spend = getattr(user_id_information, "spend", None)
if (
user_max_budget is not None
and user_current_spend is not None
):
call_info = CallInfo(
token=valid_token.token,
spend=user_current_spend,
max_budget=user_max_budget,
user_id=getattr(user_id_information, "user_id", None),
user_email=getattr(
user_id_information, "user_email", None
),
key_alias=valid_token.key_alias,
)
asyncio.create_task(
proxy_logging_obj.budget_alerts(
type="user_budget",
user_info=call_info,
)
)
if user_current_spend > user_max_budget:
raise litellm.BudgetExceededError(
current_cost=user_current_spend,
max_budget=user_max_budget,
)
# Check 3. Check if user is in their team budget # Check 3. Check if user is in their team budget
if valid_token.team_member_spend is not None: if valid_token.team_member_spend is not None:
if prisma_client is not None: if prisma_client is not None:
@ -983,7 +896,7 @@ async def user_api_key_auth(
_ = common_checks( _ = common_checks(
request_body=request_data, request_body=request_data,
team_object=_team_obj, team_object=_team_obj,
user_object=None, user_object=user_obj,
end_user_object=_end_user_object, end_user_object=_end_user_object,
general_settings=general_settings, general_settings=general_settings,
global_proxy_spend=global_proxy_spend, global_proxy_spend=global_proxy_spend,

View file

@ -116,3 +116,65 @@ def test_returned_user_api_key_auth(user_role):
assert new_obj.user_role == user_role assert new_obj.user_role == user_role
else: else:
assert new_obj.user_role == "internal_user" assert new_obj.user_role == "internal_user"
@pytest.mark.parametrize("key_ownership", ["user_key", "team_key"])
@pytest.mark.asyncio
async def test_user_personal_budgets(key_ownership):
"""
Set a personal budget on a user
- have it only apply when key belongs to user -> raises BudgetExceededError
- if key belongs to team, have key respect team budget -> allows call to go through
"""
import asyncio
import time
from fastapi import Request
from starlette.datastructures import URL
from litellm.proxy._types import LiteLLM_UserTable, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
_user_id = "1234"
user_key = "sk-12345678"
if key_ownership == "user_key":
valid_token = UserAPIKeyAuth(
token=hash_token(user_key),
last_refreshed_at=time.time(),
user_id=_user_id,
spend=20,
)
elif key_ownership == "team_key":
valid_token = UserAPIKeyAuth(
token=hash_token(user_key),
last_refreshed_at=time.time(),
user_id=_user_id,
team_id="my-special-team",
team_max_budget=100,
spend=20,
)
await asyncio.sleep(1)
user_obj = LiteLLM_UserTable(
user_id=_user_id, spend=11, max_budget=10, user_email=""
)
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
user_api_key_cache.set_cache(key="{}".format(_user_id), value=user_obj)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world")
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
try:
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
if key_ownership == "user_key":
pytest.fail("Expected this call to fail. User is over limit.")
except Exception:
if key_ownership == "team_key":
pytest.fail("Expected this call to work. Key is below team budget.")