Merge pull request #3790 from BerriAI/litellm_set_team_member_budgets

[Feat] Set Budgets for Users within a Team
This commit is contained in:
Ishaan Jaff 2024-05-22 19:44:04 -07:00 committed by GitHub
commit a8b64a01dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 270 additions and 9 deletions

View file

@ -797,12 +797,13 @@ async def user_api_key_auth(
# Run checks for
# 1. If token can call model
# 2. If user_id for this token is in budget
# 3. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
# 4. If token is expired
# 5. If token spend is under Budget for the token
# 6. If token spend per model is under budget per model
# 7. If token spend is under team budget
# 8. If team spend is under team budget
# 3. If the user spend within their own team is within budget
# 4. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
# 5. If token is expired
# 6. If token spend is under Budget for the token
# 7. If token spend per model is under budget per model
# 8. If token spend is under team budget
# 9. If team spend is under team budget
# Check 1. If token can call model
_model_alias_map = {}
@ -1000,6 +1001,43 @@ async def user_api_key_auth(
raise Exception(
f"ExceededBudget: User {valid_token.user_id} has exceeded their budget. Current spend: {user_current_spend}; Max Budget: {user_max_budget}"
)
# Check 3. Check if user is in their team budget
if valid_token.team_member_spend is not None:
if prisma_client is not None:
_cache_key = f"{valid_token.team_id}_{valid_token.user_id}"
team_member_info = await user_api_key_cache.async_get_cache(
key=_cache_key
)
if team_member_info is None:
team_member_info = (
await prisma_client.db.litellm_teammembership.find_first(
where={
"user_id": valid_token.user_id,
"team_id": valid_token.team_id,
}, # type: ignore
include={"litellm_budget_table": True},
)
)
await user_api_key_cache.async_set_cache(
key=_cache_key,
value=team_member_info,
ttl=UserAPIKeyCacheTTLEnum.user_information_cache.value,
)
if (
team_member_info is not None
and team_member_info.litellm_budget_table is not None
):
team_member_budget = (
team_member_info.litellm_budget_table.max_budget
)
if team_member_budget is not None and team_member_budget > 0:
if valid_token.team_member_spend > team_member_budget:
raise Exception(
f"ExceededBudget: Crossed spend within team. UserID: {valid_token.user_id}, in team {valid_token.team_id} has exceeded their budget. Current spend: {valid_token.team_member_spend}; Max Budget: {team_member_budget}"
)
# Check 3. If token is expired
if valid_token.expires is not None:
@ -1701,6 +1739,19 @@ async def update_database(
response_cost
+ prisma_client.team_list_transactons.get(team_id, 0)
)
try:
# Track spend of the team member within this team
# key is "team_id::<value>::user_id::<value>"
team_member_key = f"team_id::{team_id}::user_id::{user_id}"
prisma_client.team_member_list_transactons[team_member_key] = (
response_cost
+ prisma_client.team_member_list_transactons.get(
team_member_key, 0
)
)
except:
pass
except Exception as e:
verbose_proxy_logger.info(
f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}"
@ -1832,6 +1883,16 @@ async def update_cache(
# Calculate the new cost by adding the existing cost and response_cost
existing_spend_obj.team_spend = existing_team_spend + response_cost
if (
existing_spend_obj is not None
and getattr(existing_spend_obj, "team_member_spend", None) is not None
):
existing_team_member_spend = existing_spend_obj.team_member_spend or 0
# Calculate the new cost by adding the existing cost and response_cost
existing_spend_obj.team_member_spend = (
existing_team_member_spend + response_cost
)
# Update the cost column for the given token
existing_spend_obj.spend = new_spend
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)