fix(proxy_server.py): support tracking org spend

currently works when org set for jwt auth
This commit is contained in:
Krrish Dholakia 2024-04-11 23:01:21 -07:00
parent de9258e700
commit aa5da4346a
6 changed files with 136 additions and 1 deletions

View file

@ -567,6 +567,7 @@ class PrismaClient:
end_user_list_transactons: dict = {}
key_list_transactons: dict = {}
team_list_transactons: dict = {}
org_list_transactons: dict = {}
spend_log_transactions: List = []
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
@ -2150,6 +2151,46 @@ async def update_spend(
)
raise e
### UPDATE ORG TABLE ###
if len(prisma_client.org_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
org_id,
response_cost,
) in prisma_client.org_list_transactons.items():
batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists
where={"organization_id": org_id},
data={"spend": {"increment": response_cost}},
)
prisma_client.org_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break
except httpx.ReadTimeout:
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
import traceback
error_msg = (
f"LiteLLM Prisma Client Exception - update org spend: {str(e)}"
)
print_verbose(error_msg)
error_traceback = error_msg + "\n" + traceback.format_exc()
asyncio.create_task(
proxy_logging_obj.failure_handler(
original_exception=e, traceback_str=error_traceback
)
)
raise e
### UPDATE SPEND LOGS ###
verbose_proxy_logger.debug(
"Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))