fix update_database helper on db_spend_update_writer

This commit is contained in:
Ishaan Jaff 2025-03-31 19:01:00 -07:00
parent bcd49204f6
commit 3e16a51ca6
4 changed files with 73 additions and 126 deletions

View file

@ -25,6 +25,7 @@ from litellm.proxy._types import (
) )
from litellm.proxy.db.pod_lock_manager import PodLockManager from litellm.proxy.db.pod_lock_manager import PodLockManager
from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
if TYPE_CHECKING: if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient, ProxyLogging from litellm.proxy.utils import PrismaClient, ProxyLogging
@ -48,10 +49,11 @@ class DBSpendUpdateWriter:
self.redis_cache = redis_cache self.redis_cache = redis_cache
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache) self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME) self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
self.spend_update_queue = SpendUpdateQueue()
@staticmethod
async def update_database( async def update_database(
# LiteLLM management object fields # LiteLLM management object fields
self,
token: Optional[str], token: Optional[str],
user_id: Optional[str], user_id: Optional[str],
end_user_id: Optional[str], end_user_id: Optional[str],
@ -84,7 +86,7 @@ class DBSpendUpdateWriter:
hashed_token = token hashed_token = token
asyncio.create_task( asyncio.create_task(
DBSpendUpdateWriter._update_user_db( self._update_user_db(
response_cost=response_cost, response_cost=response_cost,
user_id=user_id, user_id=user_id,
prisma_client=prisma_client, prisma_client=prisma_client,
@ -94,14 +96,14 @@ class DBSpendUpdateWriter:
) )
) )
asyncio.create_task( asyncio.create_task(
DBSpendUpdateWriter._update_key_db( self._update_key_db(
response_cost=response_cost, response_cost=response_cost,
hashed_token=hashed_token, hashed_token=hashed_token,
prisma_client=prisma_client, prisma_client=prisma_client,
) )
) )
asyncio.create_task( asyncio.create_task(
DBSpendUpdateWriter._update_team_db( self._update_team_db(
response_cost=response_cost, response_cost=response_cost,
team_id=team_id, team_id=team_id,
user_id=user_id, user_id=user_id,
@ -109,7 +111,7 @@ class DBSpendUpdateWriter:
) )
) )
asyncio.create_task( asyncio.create_task(
DBSpendUpdateWriter._update_org_db( self._update_org_db(
response_cost=response_cost, response_cost=response_cost,
org_id=org_id, org_id=org_id,
prisma_client=prisma_client, prisma_client=prisma_client,
@ -135,56 +137,8 @@ class DBSpendUpdateWriter:
f"Error updating Prisma database: {traceback.format_exc()}" f"Error updating Prisma database: {traceback.format_exc()}"
) )
@staticmethod
async def _update_transaction_list(
    response_cost: Optional[float],
    entity_id: Optional[str],
    transaction_list: dict,
    entity_type: Litellm_EntityType,
    debug_msg: Optional[str] = None,
    prisma_client: Optional[PrismaClient] = None,
) -> bool:
    """
    Common helper method to update a transaction list for an entity.

    Adds ``response_cost`` to the running total stored under ``entity_id``
    in ``transaction_list`` (the dict is mutated in place).

    Args:
        response_cost: The cost to add. When None there is nothing to add,
            so the list is left untouched and False is returned.
        entity_id: The ID of the entity to update. When None, spend is not
            tracked and False is returned.
        transaction_list: The transaction list dictionary to update,
            mapping entity id -> accumulated cost.
        entity_type: The type of entity (a Litellm_EntityType member); its
            ``.value`` is only used to build log messages here.
        debug_msg: Optional custom debug message logged instead of the
            default one.
        prisma_client: Only checked for presence in this helper. When None,
            no update is recorded and False is returned.

    Returns:
        bool: True if update happened, False otherwise

    Raises:
        Exception: any error raised while updating is logged and re-raised.
    """
    try:
        if debug_msg:
            verbose_proxy_logger.debug(debug_msg)
        else:
            verbose_proxy_logger.debug(
                f"adding spend to {entity_type.value} db. Response cost: {response_cost}. {entity_type.value}_id: {entity_id}."
            )

        if prisma_client is None:
            return False

        if entity_id is None:
            verbose_proxy_logger.debug(
                f"track_cost_callback: {entity_type.value}_id is None. Not tracking spend for {entity_type.value}"
            )
            return False

        # Guard the Optional[float]: `None + number` would raise TypeError
        # below, turning a "no cost to record" call into an error.
        if response_cost is None:
            verbose_proxy_logger.debug(
                f"track_cost_callback: response_cost is None. Not tracking spend for {entity_type.value}"
            )
            return False

        transaction_list[entity_id] = response_cost + transaction_list.get(
            entity_id, 0
        )
        return True

    except Exception as e:
        verbose_proxy_logger.info(
            f"Update {entity_type.value.capitalize()} DB failed to execute - {str(e)}\n{traceback.format_exc()}"
        )
        # Bare raise re-raises the active exception with its original traceback.
        raise
@staticmethod
async def _update_key_db( async def _update_key_db(
self,
response_cost: Optional[float], response_cost: Optional[float],
hashed_token: Optional[str], hashed_token: Optional[str],
prisma_client: Optional[PrismaClient], prisma_client: Optional[PrismaClient],
@ -193,13 +147,12 @@ class DBSpendUpdateWriter:
if hashed_token is None or prisma_client is None: if hashed_token is None or prisma_client is None:
return return
await DBSpendUpdateWriter._update_transaction_list( await self.spend_update_queue.add_update(
response_cost=response_cost, update={
entity_id=hashed_token, "entity_type": Litellm_EntityType.KEY.value,
transaction_list=prisma_client.key_list_transactions, "entity_id": hashed_token,
entity_type=Litellm_EntityType.KEY, "amount": response_cost,
debug_msg=f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}.", }
prisma_client=prisma_client,
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.exception( verbose_proxy_logger.exception(
@ -207,8 +160,8 @@ class DBSpendUpdateWriter:
) )
raise e raise e
@staticmethod
async def _update_user_db( async def _update_user_db(
self,
response_cost: Optional[float], response_cost: Optional[float],
user_id: Optional[str], user_id: Optional[str],
prisma_client: Optional[PrismaClient], prisma_client: Optional[PrismaClient],
@ -234,21 +187,21 @@ class DBSpendUpdateWriter:
for _id in user_ids: for _id in user_ids:
if _id is not None: if _id is not None:
await DBSpendUpdateWriter._update_transaction_list( await self.spend_update_queue.add_update(
response_cost=response_cost, update={
entity_id=_id, "entity_type": Litellm_EntityType.USER.value,
transaction_list=prisma_client.user_list_transactions, "entity_id": _id,
entity_type=Litellm_EntityType.USER, "amount": response_cost,
prisma_client=prisma_client, }
) )
if end_user_id is not None: if end_user_id is not None:
await DBSpendUpdateWriter._update_transaction_list( await self.spend_update_queue.add_update(
response_cost=response_cost, update={
entity_id=end_user_id, "entity_type": Litellm_EntityType.END_USER.value,
transaction_list=prisma_client.end_user_list_transactions, "entity_id": end_user_id,
entity_type=Litellm_EntityType.END_USER, "amount": response_cost,
prisma_client=prisma_client, }
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.info( verbose_proxy_logger.info(
@ -256,8 +209,8 @@ class DBSpendUpdateWriter:
+ f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}" + f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
) )
@staticmethod
async def _update_team_db( async def _update_team_db(
self,
response_cost: Optional[float], response_cost: Optional[float],
team_id: Optional[str], team_id: Optional[str],
user_id: Optional[str], user_id: Optional[str],
@ -270,12 +223,12 @@ class DBSpendUpdateWriter:
) )
return return
await DBSpendUpdateWriter._update_transaction_list( await self.spend_update_queue.add_update(
response_cost=response_cost, update={
entity_id=team_id, "entity_type": Litellm_EntityType.TEAM.value,
transaction_list=prisma_client.team_list_transactions, "entity_id": team_id,
entity_type=Litellm_EntityType.TEAM, "amount": response_cost,
prisma_client=prisma_client, }
) )
try: try:
@ -283,12 +236,12 @@ class DBSpendUpdateWriter:
if user_id is not None: if user_id is not None:
# key is "team_id::<value>::user_id::<value>" # key is "team_id::<value>::user_id::<value>"
team_member_key = f"team_id::{team_id}::user_id::{user_id}" team_member_key = f"team_id::{team_id}::user_id::{user_id}"
await DBSpendUpdateWriter._update_transaction_list( await self.spend_update_queue.add_update(
response_cost=response_cost, update={
entity_id=team_member_key, "entity_type": Litellm_EntityType.TEAM_MEMBER.value,
transaction_list=prisma_client.team_member_list_transactions, "entity_id": team_member_key,
entity_type=Litellm_EntityType.TEAM_MEMBER, "amount": response_cost,
prisma_client=prisma_client, }
) )
except Exception: except Exception:
pass pass
@ -298,8 +251,8 @@ class DBSpendUpdateWriter:
) )
raise e raise e
@staticmethod
async def _update_org_db( async def _update_org_db(
self,
response_cost: Optional[float], response_cost: Optional[float],
org_id: Optional[str], org_id: Optional[str],
prisma_client: Optional[PrismaClient], prisma_client: Optional[PrismaClient],
@ -311,12 +264,12 @@ class DBSpendUpdateWriter:
) )
return return
await DBSpendUpdateWriter._update_transaction_list( await self.spend_update_queue.add_update(
response_cost=response_cost, update={
entity_id=org_id, "entity_type": Litellm_EntityType.ORGANIZATION.value,
transaction_list=prisma_client.org_list_transactions, "entity_id": org_id,
entity_type=Litellm_EntityType.ORGANIZATION, "amount": response_cost,
prisma_client=prisma_client, }
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.info( verbose_proxy_logger.info(
@ -435,7 +388,7 @@ class DBSpendUpdateWriter:
- Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB) - Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB)
""" """
await self.redis_update_buffer.store_in_memory_spend_updates_in_redis( await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
prisma_client=prisma_client, spend_update_queue=self.spend_update_queue,
) )
# Only commit from redis to db if this pod is the leader # Only commit from redis to db if this pod is the leader
@ -447,7 +400,7 @@ class DBSpendUpdateWriter:
await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer() await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer()
) )
if db_spend_update_transactions is not None: if db_spend_update_transactions is not None:
await DBSpendUpdateWriter._commit_spend_updates_to_db( await self._commit_spend_updates_to_db(
prisma_client=prisma_client, prisma_client=prisma_client,
n_retry_times=n_retry_times, n_retry_times=n_retry_times,
proxy_logging_obj=proxy_logging_obj, proxy_logging_obj=proxy_logging_obj,
@ -471,23 +424,26 @@ class DBSpendUpdateWriter:
Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS. Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS.
""" """
db_spend_update_transactions = DBSpendUpdateTransactions( aggregated_updates = (
user_list_transactions=prisma_client.user_list_transactions, await self.spend_update_queue.flush_and_get_all_aggregated_updates_by_entity_type()
end_user_list_transactions=prisma_client.end_user_list_transactions,
key_list_transactions=prisma_client.key_list_transactions,
team_list_transactions=prisma_client.team_list_transactions,
team_member_list_transactions=prisma_client.team_member_list_transactions,
org_list_transactions=prisma_client.org_list_transactions,
) )
await DBSpendUpdateWriter._commit_spend_updates_to_db( db_spend_update_transactions = DBSpendUpdateTransactions(
user_list_transactions=aggregated_updates.get("user", {}),
end_user_list_transactions=aggregated_updates.get("end_user", {}),
key_list_transactions=aggregated_updates.get("key", {}),
team_list_transactions=aggregated_updates.get("team", {}),
team_member_list_transactions=aggregated_updates.get("team_member", {}),
org_list_transactions=aggregated_updates.get("organization", {}),
)
await self._commit_spend_updates_to_db(
prisma_client=prisma_client, prisma_client=prisma_client,
n_retry_times=n_retry_times, n_retry_times=n_retry_times,
proxy_logging_obj=proxy_logging_obj, proxy_logging_obj=proxy_logging_obj,
db_spend_update_transactions=db_spend_update_transactions, db_spend_update_transactions=db_spend_update_transactions,
) )
@staticmethod
async def _commit_spend_updates_to_db( # noqa: PLR0915 async def _commit_spend_updates_to_db( # noqa: PLR0915
self,
prisma_client: PrismaClient, prisma_client: PrismaClient,
n_retry_times: int, n_retry_times: int,
proxy_logging_obj: ProxyLogging, proxy_logging_obj: ProxyLogging,
@ -526,9 +482,6 @@ class DBSpendUpdateWriter:
where={"user_id": user_id}, where={"user_id": user_id},
data={"spend": {"increment": response_cost}}, data={"spend": {"increment": response_cost}},
) )
prisma_client.user_list_transactions = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break break
except DB_CONNECTION_ERROR_TYPES as e: except DB_CONNECTION_ERROR_TYPES as e:
if ( if (
@ -583,9 +536,6 @@ class DBSpendUpdateWriter:
where={"token": token}, where={"token": token},
data={"spend": {"increment": response_cost}}, data={"spend": {"increment": response_cost}},
) )
prisma_client.key_list_transactions = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break break
except DB_CONNECTION_ERROR_TYPES as e: except DB_CONNECTION_ERROR_TYPES as e:
if ( if (
@ -632,9 +582,6 @@ class DBSpendUpdateWriter:
where={"team_id": team_id}, where={"team_id": team_id},
data={"spend": {"increment": response_cost}}, data={"spend": {"increment": response_cost}},
) )
prisma_client.team_list_transactions = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break break
except DB_CONNECTION_ERROR_TYPES as e: except DB_CONNECTION_ERROR_TYPES as e:
if ( if (
@ -684,9 +631,6 @@ class DBSpendUpdateWriter:
where={"team_id": team_id, "user_id": user_id}, where={"team_id": team_id, "user_id": user_id},
data={"spend": {"increment": response_cost}}, data={"spend": {"increment": response_cost}},
) )
prisma_client.team_member_list_transactions = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break break
except DB_CONNECTION_ERROR_TYPES as e: except DB_CONNECTION_ERROR_TYPES as e:
if ( if (
@ -725,9 +669,6 @@ class DBSpendUpdateWriter:
where={"organization_id": org_id}, where={"organization_id": org_id},
data={"spend": {"increment": response_cost}}, data={"spend": {"increment": response_cost}},
) )
prisma_client.org_list_transactions = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break break
except DB_CONNECTION_ERROR_TYPES as e: except DB_CONNECTION_ERROR_TYPES as e:
if ( if (

View file

@ -33,7 +33,6 @@ class RedisUpdateBuffer:
redis_cache: Optional[RedisCache] = None, redis_cache: Optional[RedisCache] = None,
): ):
self.redis_cache = redis_cache self.redis_cache = redis_cache
self.spend_update_queue = SpendUpdateQueue()
@staticmethod @staticmethod
def _should_commit_spend_updates_to_redis() -> bool: def _should_commit_spend_updates_to_redis() -> bool:
@ -56,6 +55,7 @@ class RedisUpdateBuffer:
async def store_in_memory_spend_updates_in_redis( async def store_in_memory_spend_updates_in_redis(
self, self,
spend_update_queue: SpendUpdateQueue,
): ):
""" """
Stores the in-memory spend updates to Redis Stores the in-memory spend updates to Redis
@ -81,9 +81,9 @@ class RedisUpdateBuffer:
return return
aggregated_updates = ( aggregated_updates = (
await self.spend_update_queue.flush_and_get_all_aggregated_updates_by_entity_type() await spend_update_queue.flush_and_get_all_aggregated_updates_by_entity_type()
) )
verbose_proxy_logger.debug("ALL AGGREGATED UPDATES: ", aggregated_updates) verbose_proxy_logger.debug("ALL AGGREGATED UPDATES: %s", aggregated_updates)
db_spend_update_transactions: DBSpendUpdateTransactions = ( db_spend_update_transactions: DBSpendUpdateTransactions = (
DBSpendUpdateTransactions( DBSpendUpdateTransactions(
@ -92,7 +92,7 @@ class RedisUpdateBuffer:
key_list_transactions=aggregated_updates.get("key", {}), key_list_transactions=aggregated_updates.get("key", {}),
team_list_transactions=aggregated_updates.get("team", {}), team_list_transactions=aggregated_updates.get("team", {}),
team_member_list_transactions=aggregated_updates.get("team_member", {}), team_member_list_transactions=aggregated_updates.get("team_member", {}),
org_list_transactions=aggregated_updates.get("org", {}), org_list_transactions=aggregated_updates.get("organization", {}),
) )
) )

View file

@ -1,6 +1,8 @@
import asyncio import asyncio
from typing import TYPE_CHECKING, Any, Dict, List from typing import TYPE_CHECKING, Any, Dict, List
from litellm._logging import verbose_proxy_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient from litellm.proxy.utils import PrismaClient
else: else:
@ -21,6 +23,7 @@ class SpendUpdateQueue:
async def add_update(self, update: Dict[str, Any]) -> None: async def add_update(self, update: Dict[str, Any]) -> None:
"""Enqueue an update. Each update might be a dict like {'entity_type': 'user', 'entity_id': '123', 'amount': 1.2}.""" """Enqueue an update. Each update might be a dict like {'entity_type': 'user', 'entity_id': '123', 'amount': 1.2}."""
verbose_proxy_logger.debug("Adding update to queue: %s", update)
await self.update_queue.put(update) await self.update_queue.put(update)
async def flush_all_updates_from_in_memory_queue(self) -> List[Dict[str, Any]]: async def flush_all_updates_from_in_memory_queue(self) -> List[Dict[str, Any]]:
@ -35,6 +38,7 @@ class SpendUpdateQueue:
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Flush all updates from the queue and return all updates aggregated by entity type.""" """Flush all updates from the queue and return all updates aggregated by entity type."""
updates = await self.flush_all_updates_from_in_memory_queue() updates = await self.flush_all_updates_from_in_memory_queue()
verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
return self.aggregate_updates_by_entity_type(updates) return self.aggregate_updates_by_entity_type(updates)
def aggregate_updates_by_entity_type( def aggregate_updates_by_entity_type(

View file

@ -37,6 +37,8 @@ class _ProxyDBLogger(CustomLogger):
if _ProxyDBLogger._should_track_errors_in_db() is False: if _ProxyDBLogger._should_track_errors_in_db() is False:
return return
from litellm.proxy.proxy_server import proxy_logging_obj
_metadata = dict( _metadata = dict(
StandardLoggingUserAPIKeyMetadata( StandardLoggingUserAPIKeyMetadata(
user_api_key_hash=user_api_key_dict.api_key, user_api_key_hash=user_api_key_dict.api_key,
@ -66,7 +68,7 @@ class _ProxyDBLogger(CustomLogger):
request_data.get("proxy_server_request") or {} request_data.get("proxy_server_request") or {}
) )
request_data["litellm_params"]["metadata"] = existing_metadata request_data["litellm_params"]["metadata"] = existing_metadata
await DBSpendUpdateWriter.update_database( await proxy_logging_obj.db_spend_update_writer.update_database(
token=user_api_key_dict.api_key, token=user_api_key_dict.api_key,
response_cost=0.0, response_cost=0.0,
user_id=user_api_key_dict.user_id, user_id=user_api_key_dict.user_id,
@ -136,7 +138,7 @@ class _ProxyDBLogger(CustomLogger):
end_user_id=end_user_id, end_user_id=end_user_id,
): ):
## UPDATE DATABASE ## UPDATE DATABASE
await DBSpendUpdateWriter.update_database( await proxy_logging_obj.db_spend_update_writer.update_database(
token=user_api_key, token=user_api_key,
response_cost=response_cost, response_cost=response_cost,
user_id=user_id, user_id=user_id,