diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 2d306ceb31..950e457f71 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -55,11 +55,11 @@ def common_checks( 1. If team is blocked 2. If team can call model 3. If team is in budget - 5. If user passed in (JWT or key.user_id) - is in budget - 4. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget - 5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints - 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget - 7. [OPTIONAL] If guardrails modified - is request allowed to change this + 4. If user passed in (JWT or key.user_id) - is in budget + 5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget + 6. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints + 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget + 8. [OPTIONAL] If guardrails modified - is request allowed to change this """ _model = request_body.get("model", None) if team_object is not None and team_object.blocked is True: @@ -91,12 +91,19 @@ def common_checks( raise Exception( f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}" ) - if user_object is not None and user_object.max_budget is not None: + # 4. If user is in budget + ## 4.1 check personal budget, if personal key + if ( + (team_object is None or team_object.team_id is None) + and user_object is not None + and user_object.max_budget is not None + ): user_budget = user_object.max_budget - if user_budget > user_object.spend: + if user_budget < user_object.spend: raise Exception( f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}" ) + ## 4.2 check team member budget, if team key # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget if end_user_object is not None and end_user_object.litellm_budget_table is not None: end_user_budget = end_user_object.litellm_budget_table.max_budget diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 3ba5ea9fda..3e03888526 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -552,6 +552,7 @@ async def user_api_key_auth( key=api_key ) if valid_token is None: + user_obj: Optional[LiteLLM_UserTable] = None ## check db verbose_proxy_logger.debug("api key: %s", api_key) if prisma_client is not None: @@ -650,114 +651,26 @@ async def user_api_key_auth( valid_token=valid_token, ) - # Check 2. If user_id for this token is in budget + # Check 2. If user_id for this token is in budget - done in common_checks() if valid_token.user_id is not None: - user_id_list = [valid_token.user_id] - for id in user_id_list: - value = user_api_key_cache.get_cache(key=id) - if value is not None: - if user_id_information is None: - user_id_information = [] - user_id_information.append(value) - if user_id_information is None or ( - isinstance(user_id_information, list) - and len(user_id_information) < 1 - ): - if prisma_client is not None: - user_id_information = await prisma_client.get_data( - user_id_list=[ - valid_token.user_id, - ], - table_name="user", - query_type="find_all", - ) - if user_id_information is not None: - for _id in user_id_information: - await user_api_key_cache.async_set_cache( - key=_id["user_id"], - value=_id, - ) + user_obj = await get_user_object( + user_id=valid_token.user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + user_id_upsert=False, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + # value = user_api_key_cache.get_cache(key=id) + if user_obj is not None: + if user_id_information is None: + user_id_information = [] + user_id_information.append(user_obj.model_dump()) verbose_proxy_logger.debug( f"user_id_information: {user_id_information}" ) - if user_id_information is not None: - if isinstance(user_id_information, list): - ## Check if user in budget - for _user in user_id_information: - if _user is None: - continue - assert isinstance(_user, dict) - # check if user is admin # - - # Token exists, not expired now check if its in budget for the user - user_max_budget = _user.get("max_budget", None) - user_current_spend = _user.get("spend", None) - - verbose_proxy_logger.debug( - f"user_id: {_user.get('user_id', None)}; user_max_budget: {user_max_budget}; user_current_spend: {user_current_spend}" - ) - - if ( - user_max_budget is not None - and user_current_spend is not None - ): - call_info = CallInfo( - token=valid_token.token, - spend=user_current_spend, - max_budget=user_max_budget, - user_id=_user.get("user_id", None), - user_email=_user.get("user_email", None), - key_alias=valid_token.key_alias, - ) - asyncio.create_task( - proxy_logging_obj.budget_alerts( - type="user_budget", - user_info=call_info, - ) - ) - - _user_id = _user.get("user_id", None) - if user_current_spend > user_max_budget: - raise litellm.BudgetExceededError( - current_cost=user_current_spend, - max_budget=user_max_budget, - ) - else: - # Token exists, not expired now check if its in budget for the user - user_max_budget = getattr( - user_id_information, "max_budget", None - ) - user_current_spend = getattr(user_id_information, "spend", None) - - if ( - user_max_budget is not None - and user_current_spend is not None - ): - call_info = CallInfo( - token=valid_token.token, - spend=user_current_spend, - max_budget=user_max_budget, - user_id=getattr(user_id_information, "user_id", None), - user_email=getattr( - user_id_information, "user_email", None - ), - key_alias=valid_token.key_alias, - ) - asyncio.create_task( - proxy_logging_obj.budget_alerts( - type="user_budget", - user_info=call_info, - ) - ) - - if user_current_spend > user_max_budget: - raise litellm.BudgetExceededError( - current_cost=user_current_spend, - max_budget=user_max_budget, - ) - # Check 3. Check if user is in their team budget if valid_token.team_member_spend is not None: if prisma_client is not None: @@ -983,7 +896,7 @@ async def user_api_key_auth( _ = common_checks( request_body=request_data, team_object=_team_obj, - user_object=None, + user_object=user_obj, end_user_object=_end_user_object, general_settings=general_settings, global_proxy_spend=global_proxy_spend, diff --git a/litellm/tests/test_user_api_key_auth.py b/litellm/tests/test_user_api_key_auth.py index 1ba81d4fa4..e8f5a8e08a 100644 --- a/litellm/tests/test_user_api_key_auth.py +++ b/litellm/tests/test_user_api_key_auth.py @@ -116,3 +116,65 @@ def test_returned_user_api_key_auth(user_role): assert new_obj.user_role == user_role else: assert new_obj.user_role == "internal_user" + + +@pytest.mark.parametrize("key_ownership", ["user_key", "team_key"]) +@pytest.mark.asyncio +async def test_user_personal_budgets(key_ownership): + """ + Set a personal budget on a user + + - have it only apply when key belongs to user -> raises BudgetExceededError + - if key belongs to team, have key respect team budget -> allows call to go through + """ + import asyncio + import time + + from fastapi import Request + from starlette.datastructures import URL + + from litellm.proxy._types import LiteLLM_UserTable, UserAPIKeyAuth + from litellm.proxy.auth.user_api_key_auth import user_api_key_auth + from litellm.proxy.proxy_server import hash_token, user_api_key_cache + + _user_id = "1234" + user_key = "sk-12345678" + + if key_ownership == "user_key": + valid_token = UserAPIKeyAuth( + token=hash_token(user_key), + last_refreshed_at=time.time(), + user_id=_user_id, + spend=20, + ) + elif key_ownership == "team_key": + valid_token = UserAPIKeyAuth( + token=hash_token(user_key), + last_refreshed_at=time.time(), + user_id=_user_id, + team_id="my-special-team", + team_max_budget=100, + spend=20, + ) + await asyncio.sleep(1) + user_obj = LiteLLM_UserTable( + user_id=_user_id, spend=11, max_budget=10, user_email="" + ) + user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token) + user_api_key_cache.set_cache(key="{}".format(_user_id), value=user_obj) + + setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world") + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + try: + await user_api_key_auth(request=request, api_key="Bearer " + user_key) + + if key_ownership == "user_key": + pytest.fail("Expected this call to fail. User is over limit.") + except Exception: + if key_ownership == "team_key": + pytest.fail("Expected this call to work. Key is below team budget.")