diff --git a/docs/my-website/docs/proxy/users.md b/docs/my-website/docs/proxy/users.md index 866c0c4bb..afa888ba7 100644 --- a/docs/my-website/docs/proxy/users.md +++ b/docs/my-website/docs/proxy/users.md @@ -473,4 +473,75 @@ curl --location 'http://0.0.0.0:4000/key/generate' \ --header 'Authorization: Bearer ' \ --header 'Content-Type: application/json' \ --data '{"models": ["azure-models"], "user_id": "krrish@berri.ai"}' -``` \ No newline at end of file +``` + + +## Advanced +### Set Budgets for Members within a Team + +Use this when you want to budget a user's spend within a Team + + +#### Step 1. Create User + +Create a user with `user_id=ishaan` + +```shell +curl --location 'http://0.0.0.0:4000/user/new' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "user_id": "ishaan" +}' +``` + +#### Step 2. Add User to an existing Team - set `max_budget_in_team` + +Set `max_budget_in_team` when adding a User to a team. We use the same `user_id` we set in Step 1 + +```shell +curl -X POST 'http://0.0.0.0:4000/team/member_add' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{"team_id": "e8d1460f-846c-45d7-9b43-55f3cc52ac32", "max_budget_in_team": 0.000000000001, "member": {"role": "user", "user_id": "ishaan"}}' +``` + +#### Step 3. 
Create a Key for user from Step 1 + +Set `user_id=ishaan` from step 1 + +```shell +curl --location 'http://0.0.0.0:4000/key/generate' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "user_id": "ishaan" +}' +``` +Response from `/key/generate` + +We use the `key` from this response in Step 4 +```shell +{"models":[],"spend":0.0,"max_budget":null,"user_id":"ishaan","team_id":null,"max_parallel_requests":null,"metadata":{},"tpm_limit":null,"rpm_limit":null,"budget_duration":null,"allowed_cache_controls":[],"soft_budget":null,"key_alias":null,"duration":null,"aliases":{},"config":{},"permissions":{},"model_max_budget":{},"key":"sk-RV-l2BJEZ_LYNChSx2EueQ","key_name":null,"expires":null,"token_id":null}% +``` + +#### Step 4. Make /chat/completions requests for user + +Use the key from step 3 for this request. After 2-3 requests expect to see the following error `ExceededBudget: Crossed spend within team` + + +```shell +curl --location 'http://localhost:4000/chat/completions' \ + --header 'Authorization: Bearer sk-RV-l2BJEZ_LYNChSx2EueQ' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "llama3", + "messages": [ + { + "role": "user", + "content": "tes4" + } + ] +}' +``` + diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 34a40b0d8..148f2d11c 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -648,6 +648,20 @@ class LiteLLM_BudgetTable(LiteLLMBase): protected_namespaces = () +class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable): + """ + Used to track spend of a user_id within a team_id + """ + + spend: Optional[float] = None + user_id: Optional[str] = None + team_id: Optional[str] = None + budget_id: Optional[str] = None + + class Config: + protected_namespaces = () + + class NewOrganizationRequest(LiteLLM_BudgetTable): organization_id: Optional[str] = None organization_alias: str @@ -942,6 +956,7 @@ class 
LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): team_blocked: bool = False soft_budget: Optional[float] = None team_model_aliases: Optional[Dict] = None + team_member_spend: Optional[float] = None # End User Params end_user_id: Optional[str] = None diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 7af5b4c94..25d6e79f8 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -797,12 +797,13 @@ async def user_api_key_auth( # Run checks for # 1. If token can call model # 2. If user_id for this token is in budget - # 3. If 'user' passed to /chat/completions, /embeddings endpoint is in budget - # 4. If token is expired - # 5. If token spend is under Budget for the token - # 6. If token spend per model is under budget per model - # 7. If token spend is under team budget - # 8. If team spend is under team budget + # 3. If the user spend within their own team is within budget + # 4. If 'user' passed to /chat/completions, /embeddings endpoint is in budget + # 5. If token is expired + # 6. If token spend is under Budget for the token + # 7. If token spend per model is under budget per model + # 8. If token spend is under team budget + # 9. If team spend is under team budget # Check 1. If token can call model _model_alias_map = {} @@ -1000,6 +1001,43 @@ async def user_api_key_auth( raise Exception( f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}" ) + # Check 3. 
Check if user is in their team budget + if valid_token.team_member_spend is not None: + if prisma_client is not None: + + _cache_key = f"{valid_token.team_id}_{valid_token.user_id}" + + team_member_info = await user_api_key_cache.async_get_cache( + key=_cache_key + ) + if team_member_info is None: + team_member_info = ( + await prisma_client.db.litellm_teammembership.find_first( + where={ + "user_id": valid_token.user_id, + "team_id": valid_token.team_id, + }, # type: ignore + include={"litellm_budget_table": True}, + ) + ) + await user_api_key_cache.async_set_cache( + key=_cache_key, + value=team_member_info, + ttl=UserAPIKeyCacheTTLEnum.user_information_cache.value, + ) + + if ( + team_member_info is not None + and team_member_info.litellm_budget_table is not None + ): + team_member_budget = ( + team_member_info.litellm_budget_table.max_budget + ) + if team_member_budget is not None and team_member_budget > 0: + if valid_token.team_member_spend > team_member_budget: + raise Exception( + f"ExceededBudget: Crossed spend within team. UserID: {valid_token.user_id}, in team {valid_token.team_id} has exceeded their budget. Current spend: {valid_token.team_member_spend}; Max Budget: {team_member_budget}" + ) # Check 3. 
If token is expired if valid_token.expires is not None: @@ -1701,6 +1739,19 @@ async def update_database( response_cost + prisma_client.team_list_transactons.get(team_id, 0) ) + + try: + # Track spend of the team member within this team + # key is "team_id::{team_id}::user_id::{user_id}" + team_member_key = f"team_id::{team_id}::user_id::{user_id}" + prisma_client.team_member_list_transactons[team_member_key] = ( + response_cost + + prisma_client.team_member_list_transactons.get( + team_member_key, 0 + ) + ) + except: + pass except Exception as e: verbose_proxy_logger.info( f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}" @@ -1832,6 +1883,16 @@ async def update_cache( # Calculate the new cost by adding the existing cost and response_cost existing_spend_obj.team_spend = existing_team_spend + response_cost + if ( + existing_spend_obj is not None + and getattr(existing_spend_obj, "team_member_spend", None) is not None + ): + existing_team_member_spend = existing_spend_obj.team_member_spend or 0 + # Calculate the new cost by adding the existing cost and response_cost + existing_spend_obj.team_member_spend = ( + existing_team_member_spend + response_cost + ) + # Update the cost column for the given token existing_spend_obj.spend = new_spend user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index fdea28c2d..8522b3259 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -551,6 +551,7 @@ class PrismaClient: end_user_list_transactons: dict = {} key_list_transactons: dict = {} team_list_transactons: dict = {} + team_member_list_transactons: dict = {} # key is ["team_id" + "user_id"] org_list_transactons: dict = {} spend_log_transactions: List = [] @@ -1096,9 +1097,11 @@ class PrismaClient: t.models AS team_models, t.blocked AS team_blocked, t.team_alias AS team_alias, + tm.spend AS team_member_spend, m.aliases as team_model_aliases FROM "LiteLLM_VerificationToken" AS v 
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id + LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id WHERE v.token = '{token}' """ @@ -2262,6 +2265,56 @@ async def update_spend( ) raise e + ### UPDATE TEAM Membership TABLE with spend ### + if len(prisma_client.team_member_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + key, + response_cost, + ) in prisma_client.team_member_list_transactons.items(): + # key is "team_id::{team_id}::user_id::{user_id}" + team_id = key.split("::")[1] + user_id = key.split("::")[3] + + batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists + where={"team_id": team_id, "user_id": user_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.team_member_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. 
+ break + except httpx.ReadTimeout: + if i >= n_retry_times: # If we've reached the maximum number of retries + raise # Re-raise the last exception + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + import traceback + + error_msg = ( + f"LiteLLM Prisma Client Exception - update team spend: {str(e)}" + ) + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + proxy_logging_obj.failure_handler( + original_exception=e, + duration=_duration, + call_type="update_spend", + traceback_str=error_traceback, + ) + ) + raise e + ### UPDATE ORG TABLE ### if len(prisma_client.org_list_transactons.keys()) > 0: for i in range(n_retry_times + 1): diff --git a/tests/test_team.py b/tests/test_team.py index 2cc384a74..17491b396 100644 --- a/tests/test_team.py +++ b/tests/test_team.py @@ -8,7 +8,13 @@ from openai import AsyncOpenAI async def new_user( - session, i, user_id=None, budget=None, budget_duration=None, models=["azure-models"] + session, + i, + user_id=None, + budget=None, + budget_duration=None, + models=["azure-models"], + team_id=None, ): url = "http://0.0.0.0:4000/user/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} @@ -23,6 +29,9 @@ async def new_user( if user_id is not None: data["user_id"] = user_id + if team_id is not None: + data["team_id"] = team_id + async with session.post(url, headers=headers, json=data) as response: status = response.status response_text = await response.text() @@ -37,7 +46,9 @@ async def new_user( return await response.json() -async def add_member(session, i, team_id, user_id=None, user_email=None): +async def add_member( + session, i, team_id, user_id=None, user_email=None, max_budget=None +): url = "http://0.0.0.0:4000/team/member_add" headers = {"Authorization": "Bearer sk-1234", "Content-Type": 
"application/json"} data = {"team_id": team_id, "member": {"role": "user"}} @@ -46,6 +57,9 @@ async def add_member(session, i, team_id, user_id=None, user_email=None): elif user_id is not None: data["member"]["user_id"] = user_id + if max_budget is not None: + data["max_budget_in_team"] = max_budget + async with session.post(url, headers=headers, json=data) as response: status = response.status response_text = await response.text() @@ -475,3 +489,50 @@ async def test_team_alias(): key = key_gen["key"] ## Test key response = await chat_completion(session=session, key=key, model="cheap-model") + + +@pytest.mark.asyncio +async def test_users_in_team_budget(): + """ + - Create Team + - Create User + - Add User to team with budget = 0.0000001 + - Make Call 1 -> pass + - Make Call 2 -> fail + """ + get_user = f"krrish_{time.time()}@berri.ai" + async with aiohttp.ClientSession() as session: + team = await new_team(session, 0, user_id=get_user) + print("New team=", team) + key_gen = await new_user( + session, + 0, + user_id=get_user, + budget=10, + budget_duration="5s", + team_id=team["team_id"], + models=["fake-openai-endpoint"], + ) + key = key_gen["key"] + + # Add user to team + await add_member( + session, 0, team_id=team["team_id"], user_id=get_user, max_budget=0.0000001 + ) + + # Call 1 + result = await chat_completion(session, key, model="fake-openai-endpoint") + print("Call 1 passed", result) + + await asyncio.sleep(2) + + # Call 2 + try: + await chat_completion(session, key, model="fake-openai-endpoint") + pytest.fail( + "Call 2 should have failed. The user crossed their budget within their team" + ) + except Exception as e: + print("got exception, this is expected") + print(e) + assert "Crossed spend within team" in str(e)