mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
refactor, use commit_update_transactions_to_db
This commit is contained in:
parent
7f24714576
commit
33f877e8b0
2 changed files with 252 additions and 198 deletions
|
@ -7,16 +7,28 @@ Module responsible for
|
|||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import Litellm_EntityType, LiteLLM_UserTable, SpendLogsPayload
|
||||
from litellm.proxy._types import (
|
||||
DB_CONNECTION_ERROR_TYPES,
|
||||
Litellm_EntityType,
|
||||
LiteLLM_UserTable,
|
||||
SpendLogsPayload,
|
||||
)
|
||||
from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload
|
||||
from litellm.proxy.utils import PrismaClient, ProxyUpdateSpend, hash_token
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient,
|
||||
ProxyLogging,
|
||||
ProxyUpdateSpend,
|
||||
_raise_failed_update_spend_exception,
|
||||
hash_token,
|
||||
)
|
||||
|
||||
|
||||
class DBSpendUpdateWriter:
|
||||
|
@ -346,3 +358,235 @@ class DBSpendUpdateWriter:
|
|||
payload.copy()
|
||||
)
|
||||
return prisma_client
|
||||
|
||||
@staticmethod
|
||||
async def commit_update_transactions_to_db( # noqa: PLR0915
|
||||
prisma_client: PrismaClient,
|
||||
n_retry_times: int,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
):
|
||||
"""
|
||||
Handles commiting update spend transactions to db
|
||||
|
||||
UPDATES can lead to deadlocks, hence we handle them separately
|
||||
|
||||
Args:
|
||||
prisma_client: PrismaClient object
|
||||
n_retry_times: int, number of retry times
|
||||
proxy_logging_obj: ProxyLogging object
|
||||
"""
|
||||
### UPDATE USER TABLE ###
|
||||
if len(prisma_client.user_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
user_id,
|
||||
response_cost,
|
||||
) in prisma_client.user_list_transactons.items():
|
||||
batcher.litellm_usertable.update_many(
|
||||
where={"user_id": user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.user_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if (
|
||||
i >= n_retry_times
|
||||
): # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e,
|
||||
start_time=start_time,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE END-USER TABLE ###
|
||||
verbose_proxy_logger.debug(
|
||||
"End-User Spend transactions: {}".format(
|
||||
len(prisma_client.end_user_list_transactons.keys())
|
||||
)
|
||||
)
|
||||
if len(prisma_client.end_user_list_transactons.keys()) > 0:
|
||||
await ProxyUpdateSpend.update_end_user_spend(
|
||||
n_retry_times=n_retry_times,
|
||||
prisma_client=prisma_client,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
### UPDATE KEY TABLE ###
|
||||
verbose_proxy_logger.debug(
|
||||
"KEY Spend transactions: {}".format(
|
||||
len(prisma_client.key_list_transactons.keys())
|
||||
)
|
||||
)
|
||||
if len(prisma_client.key_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
token,
|
||||
response_cost,
|
||||
) in prisma_client.key_list_transactons.items():
|
||||
batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"token": token},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.key_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if (
|
||||
i >= n_retry_times
|
||||
): # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e,
|
||||
start_time=start_time,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE TEAM TABLE ###
|
||||
verbose_proxy_logger.debug(
|
||||
"Team Spend transactions: {}".format(
|
||||
len(prisma_client.team_list_transactons.keys())
|
||||
)
|
||||
)
|
||||
if len(prisma_client.team_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
team_id,
|
||||
response_cost,
|
||||
) in prisma_client.team_list_transactons.items():
|
||||
verbose_proxy_logger.debug(
|
||||
"Updating spend for team id={} by {}".format(
|
||||
team_id, response_cost
|
||||
)
|
||||
)
|
||||
batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"team_id": team_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.team_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if (
|
||||
i >= n_retry_times
|
||||
): # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e,
|
||||
start_time=start_time,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE TEAM Membership TABLE with spend ###
|
||||
if len(prisma_client.team_member_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
key,
|
||||
response_cost,
|
||||
) in prisma_client.team_member_list_transactons.items():
|
||||
# key is "team_id::<value>::user_id::<value>"
|
||||
team_id = key.split("::")[1]
|
||||
user_id = key.split("::")[3]
|
||||
|
||||
batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"team_id": team_id, "user_id": user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.team_member_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if (
|
||||
i >= n_retry_times
|
||||
): # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e,
|
||||
start_time=start_time,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE ORG TABLE ###
|
||||
if len(prisma_client.org_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
org_id,
|
||||
response_cost,
|
||||
) in prisma_client.org_list_transactons.items():
|
||||
batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"organization_id": org_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.org_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if (
|
||||
i >= n_retry_times
|
||||
): # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e,
|
||||
start_time=start_time,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
|
|
@ -62,6 +62,7 @@ from litellm.proxy.db.create_views import (
|
|||
create_missing_views,
|
||||
should_create_missing_views,
|
||||
)
|
||||
from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter
|
||||
from litellm.proxy.db.log_db_metrics import log_db_metrics
|
||||
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
|
||||
|
@ -2674,202 +2675,11 @@ async def update_spend( # noqa: PLR0915
|
|||
spend_logs: list,
|
||||
"""
|
||||
n_retry_times = 3
|
||||
i = None
|
||||
### UPDATE USER TABLE ###
|
||||
if len(prisma_client.user_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
user_id,
|
||||
response_cost,
|
||||
) in prisma_client.user_list_transactons.items():
|
||||
batcher.litellm_usertable.update_many(
|
||||
where={"user_id": user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.user_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE END-USER TABLE ###
|
||||
verbose_proxy_logger.debug(
|
||||
"End-User Spend transactions: {}".format(
|
||||
len(prisma_client.end_user_list_transactons.keys())
|
||||
)
|
||||
await DBSpendUpdateWriter.commit_update_transactions_to_db(
|
||||
prisma_client=prisma_client,
|
||||
n_retry_times=n_retry_times,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
if len(prisma_client.end_user_list_transactons.keys()) > 0:
|
||||
await ProxyUpdateSpend.update_end_user_spend(
|
||||
n_retry_times=n_retry_times,
|
||||
prisma_client=prisma_client,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
### UPDATE KEY TABLE ###
|
||||
verbose_proxy_logger.debug(
|
||||
"KEY Spend transactions: {}".format(
|
||||
len(prisma_client.key_list_transactons.keys())
|
||||
)
|
||||
)
|
||||
if len(prisma_client.key_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
token,
|
||||
response_cost,
|
||||
) in prisma_client.key_list_transactons.items():
|
||||
batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"token": token},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.key_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE TEAM TABLE ###
|
||||
verbose_proxy_logger.debug(
|
||||
"Team Spend transactions: {}".format(
|
||||
len(prisma_client.team_list_transactons.keys())
|
||||
)
|
||||
)
|
||||
if len(prisma_client.team_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
team_id,
|
||||
response_cost,
|
||||
) in prisma_client.team_list_transactons.items():
|
||||
verbose_proxy_logger.debug(
|
||||
"Updating spend for team id={} by {}".format(
|
||||
team_id, response_cost
|
||||
)
|
||||
)
|
||||
batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"team_id": team_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.team_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE TEAM Membership TABLE with spend ###
|
||||
if len(prisma_client.team_member_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
key,
|
||||
response_cost,
|
||||
) in prisma_client.team_member_list_transactons.items():
|
||||
# key is "team_id::<value>::user_id::<value>"
|
||||
team_id = key.split("::")[1]
|
||||
user_id = key.split("::")[3]
|
||||
|
||||
batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"team_id": team_id, "user_id": user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.team_member_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE ORG TABLE ###
|
||||
if len(prisma_client.org_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with prisma_client.db.tx(
|
||||
timeout=timedelta(seconds=60)
|
||||
) as transaction:
|
||||
async with transaction.batch_() as batcher:
|
||||
for (
|
||||
org_id,
|
||||
response_cost,
|
||||
) in prisma_client.org_list_transactons.items():
|
||||
batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists
|
||||
where={"organization_id": org_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
)
|
||||
prisma_client.org_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
break
|
||||
except DB_CONNECTION_ERROR_TYPES as e:
|
||||
if i >= n_retry_times: # If we've reached the maximum number of retries
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
# Optionally, sleep for a bit before retrying
|
||||
await asyncio.sleep(2**i) # Exponential backoff
|
||||
except Exception as e:
|
||||
_raise_failed_update_spend_exception(
|
||||
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### UPDATE SPEND LOGS ###
|
||||
verbose_proxy_logger.debug(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue