fix(proxy_server.py): enforce team based spend limits

This commit is contained in:
Krrish Dholakia 2024-02-26 15:45:25 -08:00
parent f7af18d72a
commit 71d4b7aaf4
3 changed files with 146 additions and 4 deletions

View file

@ -350,13 +350,14 @@ async def user_api_key_auth(
original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility)
if api_key.startswith("sk-"):
api_key = hash_token(token=api_key)
valid_token = user_api_key_cache.get_cache(key=api_key)
# valid_token = user_api_key_cache.get_cache(key=api_key)
valid_token = None
if valid_token is None:
## check db
verbose_proxy_logger.debug(f"api key: {api_key}")
if prisma_client is not None:
valid_token = await prisma_client.get_data(
token=api_key,
token=api_key, table_name="combined_view"
)
elif custom_db_client is not None:
@ -381,6 +382,8 @@ async def user_api_key_auth(
# 4. If token is expired
# 5. If token spend is under Budget for the token
# 6. If token spend per model is under budget per model
# 7. If token spend is under team budget
# 8. If team spend is under team budget
request_data = await _read_request_body(
request=request
@ -610,6 +613,44 @@ async def user_api_key_auth(
f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}"
)
# Check 6. Token spend is under Team budget
if (
valid_token.spend is not None
and valid_token.team_max_budget is not None
):
asyncio.create_task(
proxy_logging_obj.budget_alerts(
user_max_budget=valid_token.team_max_budget,
user_current_spend=valid_token.spend,
type="token_budget",
user_info=valid_token,
)
)
if valid_token.spend > valid_token.team_max_budget:
raise Exception(
f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}"
)
# Check 7. Team spend is under Team budget
if (
valid_token.team_spend is not None
and valid_token.team_max_budget is not None
):
asyncio.create_task(
proxy_logging_obj.budget_alerts(
user_max_budget=valid_token.team_max_budget,
user_current_spend=valid_token.team_spend,
type="token_budget",
user_info=valid_token,
)
)
if valid_token.team_spend > valid_token.team_max_budget:
raise Exception(
f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
)
# Token passed all checks
api_key = valid_token.token
@ -2256,6 +2297,10 @@ async def startup_event():
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
### CHECK IF VIEW EXISTS ###
create_view_response = await prisma_client.check_view_exists()
print(f"create_view_response: {create_view_response}") # noqa
### START BUDGET SCHEDULER ###
if prisma_client is not None:
scheduler = AsyncIOScheduler()