Route proxy spend updates through a SpendUpdateQueue owned by DBSpendUpdateWriter.

Summary of what the hunks below show:
  * DBSpendUpdateWriter.update_database, the _update_*_db helpers and
    _commit_spend_updates_to_db lose @staticmethod and take `self`.
  * Each _update_*_db helper enqueues a dict with the keys "entity_type",
    "entity_id" and "amount" on self.spend_update_queue instead of calling the
    removed _update_transaction_list helper on prisma_client.*_list_transactions.
  * The non-Redis commit path builds DBSpendUpdateTransactions from the flushed,
    aggregated queue; the per-table `prisma_client.*_list_transactions = {}`
    resets are removed.
  * RedisUpdateBuffer no longer owns a queue; it receives one as an argument,
    reads the "organization" key (previously "org"), and its debug call gains
    the missing %s placeholder.
  * _ProxyDBLogger calls proxy_logging_obj.db_spend_update_writer.update_database.

NOTE(review): line breaks and indentation were reconstructed from the hunk line
counts and Python block structure; the absolute indentation of context lines is
an assumption. Regenerate with `git diff` before applying.
NOTE(review): the queue is flushed before _commit_spend_updates_to_db runs and
the post-batch resets are gone; confirm a failed commit cannot drop the flushed
updates.
NOTE(review): "amount" is passed an Optional[float]; confirm the queue's
aggregation tolerates None.
NOTE(review): proxy_logging_obj is imported only in the first
proxy_track_cost_callback hunk; confirm it is in scope at the second call site.

diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py
index f46b03b57a..6b5719acbb 100644
--- a/litellm/proxy/db/db_spend_update_writer.py
+++ b/litellm/proxy/db/db_spend_update_writer.py
@@ -25,6 +25,7 @@ from litellm.proxy._types import (
 )
 from litellm.proxy.db.pod_lock_manager import PodLockManager
 from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
+from litellm.proxy.db.spend_update_queue import SpendUpdateQueue
 
 if TYPE_CHECKING:
     from litellm.proxy.utils import PrismaClient, ProxyLogging
@@ -48,10 +49,11 @@ class DBSpendUpdateWriter:
         self.redis_cache = redis_cache
         self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
         self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
+        self.spend_update_queue = SpendUpdateQueue()
 
-    @staticmethod
     async def update_database(
         # LiteLLM management object fields
+        self,
         token: Optional[str],
         user_id: Optional[str],
         end_user_id: Optional[str],
@@ -84,7 +86,7 @@ class DBSpendUpdateWriter:
                 hashed_token = token
 
             asyncio.create_task(
-                DBSpendUpdateWriter._update_user_db(
+                self._update_user_db(
                     response_cost=response_cost,
                     user_id=user_id,
                     prisma_client=prisma_client,
@@ -94,14 +96,14 @@ class DBSpendUpdateWriter:
                 )
             )
             asyncio.create_task(
-                DBSpendUpdateWriter._update_key_db(
+                self._update_key_db(
                     response_cost=response_cost,
                     hashed_token=hashed_token,
                     prisma_client=prisma_client,
                 )
             )
             asyncio.create_task(
-                DBSpendUpdateWriter._update_team_db(
+                self._update_team_db(
                     response_cost=response_cost,
                     team_id=team_id,
                     user_id=user_id,
@@ -109,7 +111,7 @@ class DBSpendUpdateWriter:
                 )
             )
             asyncio.create_task(
-                DBSpendUpdateWriter._update_org_db(
+                self._update_org_db(
                     response_cost=response_cost,
                     org_id=org_id,
                     prisma_client=prisma_client,
@@ -135,56 +137,8 @@ class DBSpendUpdateWriter:
                 f"Error updating Prisma database: {traceback.format_exc()}"
             )
 
-    @staticmethod
-    async def _update_transaction_list(
-        response_cost: Optional[float],
-        entity_id: Optional[str],
-        transaction_list: dict,
-        entity_type: Litellm_EntityType,
-        debug_msg: Optional[str] = None,
-        prisma_client: Optional[PrismaClient] = None,
-    ) -> bool:
-        """
-        Common helper method to update a transaction list for an entity
-
-        Args:
-            response_cost: The cost to add
-            entity_id: The ID of the entity to update
-            transaction_list: The transaction list dictionary to update
-            entity_type: The type of entity (from EntityType enum)
-            debug_msg: Optional custom debug message
-
-        Returns:
-            bool: True if update happened, False otherwise
-        """
-        try:
-            if debug_msg:
-                verbose_proxy_logger.debug(debug_msg)
-            else:
-                verbose_proxy_logger.debug(
-                    f"adding spend to {entity_type.value} db. Response cost: {response_cost}. {entity_type.value}_id: {entity_id}."
-                )
-            if prisma_client is None:
-                return False
-
-            if entity_id is None:
-                verbose_proxy_logger.debug(
-                    f"track_cost_callback: {entity_type.value}_id is None. Not tracking spend for {entity_type.value}"
-                )
-                return False
-            transaction_list[entity_id] = response_cost + transaction_list.get(
-                entity_id, 0
-            )
-            return True
-
-        except Exception as e:
-            verbose_proxy_logger.info(
-                f"Update {entity_type.value.capitalize()} DB failed to execute - {str(e)}\n{traceback.format_exc()}"
-            )
-            raise e
-
-    @staticmethod
     async def _update_key_db(
+        self,
         response_cost: Optional[float],
         hashed_token: Optional[str],
         prisma_client: Optional[PrismaClient],
@@ -193,13 +147,12 @@ class DBSpendUpdateWriter:
             if hashed_token is None or prisma_client is None:
                 return
 
-            await DBSpendUpdateWriter._update_transaction_list(
-                response_cost=response_cost,
-                entity_id=hashed_token,
-                transaction_list=prisma_client.key_list_transactions,
-                entity_type=Litellm_EntityType.KEY,
-                debug_msg=f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}.",
-                prisma_client=prisma_client,
+            await self.spend_update_queue.add_update(
+                update={
+                    "entity_type": Litellm_EntityType.KEY.value,
+                    "entity_id": hashed_token,
+                    "amount": response_cost,
+                }
             )
         except Exception as e:
             verbose_proxy_logger.exception(
@@ -207,8 +160,8 @@ class DBSpendUpdateWriter:
             )
             raise e
 
-    @staticmethod
     async def _update_user_db(
+        self,
         response_cost: Optional[float],
         user_id: Optional[str],
         prisma_client: Optional[PrismaClient],
@@ -234,21 +187,21 @@ class DBSpendUpdateWriter:
 
             for _id in user_ids:
                 if _id is not None:
-                    await DBSpendUpdateWriter._update_transaction_list(
-                        response_cost=response_cost,
-                        entity_id=_id,
-                        transaction_list=prisma_client.user_list_transactions,
-                        entity_type=Litellm_EntityType.USER,
-                        prisma_client=prisma_client,
+                    await self.spend_update_queue.add_update(
+                        update={
+                            "entity_type": Litellm_EntityType.USER.value,
+                            "entity_id": _id,
+                            "amount": response_cost,
+                        }
                     )
 
             if end_user_id is not None:
-                await DBSpendUpdateWriter._update_transaction_list(
-                    response_cost=response_cost,
-                    entity_id=end_user_id,
-                    transaction_list=prisma_client.end_user_list_transactions,
-                    entity_type=Litellm_EntityType.END_USER,
-                    prisma_client=prisma_client,
+                await self.spend_update_queue.add_update(
+                    update={
+                        "entity_type": Litellm_EntityType.END_USER.value,
+                        "entity_id": end_user_id,
+                        "amount": response_cost,
+                    }
                 )
         except Exception as e:
             verbose_proxy_logger.info(
@@ -256,8 +209,8 @@ class DBSpendUpdateWriter:
                 + f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
             )
 
-    @staticmethod
     async def _update_team_db(
+        self,
         response_cost: Optional[float],
         team_id: Optional[str],
         user_id: Optional[str],
@@ -270,12 +223,12 @@ class DBSpendUpdateWriter:
                 )
                 return
 
-            await DBSpendUpdateWriter._update_transaction_list(
-                response_cost=response_cost,
-                entity_id=team_id,
-                transaction_list=prisma_client.team_list_transactions,
-                entity_type=Litellm_EntityType.TEAM,
-                prisma_client=prisma_client,
+            await self.spend_update_queue.add_update(
+                update={
+                    "entity_type": Litellm_EntityType.TEAM.value,
+                    "entity_id": team_id,
+                    "amount": response_cost,
+                }
             )
 
             try:
@@ -283,12 +236,12 @@ class DBSpendUpdateWriter:
                 if user_id is not None:
                     # key is "team_id::::user_id::"
                     team_member_key = f"team_id::{team_id}::user_id::{user_id}"
-                    await DBSpendUpdateWriter._update_transaction_list(
-                        response_cost=response_cost,
-                        entity_id=team_member_key,
-                        transaction_list=prisma_client.team_member_list_transactions,
-                        entity_type=Litellm_EntityType.TEAM_MEMBER,
-                        prisma_client=prisma_client,
+                    await self.spend_update_queue.add_update(
+                        update={
+                            "entity_type": Litellm_EntityType.TEAM_MEMBER.value,
+                            "entity_id": team_member_key,
+                            "amount": response_cost,
+                        }
                     )
             except Exception:
                 pass
@@ -298,8 +251,8 @@ class DBSpendUpdateWriter:
             )
             raise e
 
-    @staticmethod
     async def _update_org_db(
+        self,
         response_cost: Optional[float],
         org_id: Optional[str],
         prisma_client: Optional[PrismaClient],
@@ -311,12 +264,12 @@ class DBSpendUpdateWriter:
                 )
                 return
 
-            await DBSpendUpdateWriter._update_transaction_list(
-                response_cost=response_cost,
-                entity_id=org_id,
-                transaction_list=prisma_client.org_list_transactions,
-                entity_type=Litellm_EntityType.ORGANIZATION,
-                prisma_client=prisma_client,
+            await self.spend_update_queue.add_update(
+                update={
+                    "entity_type": Litellm_EntityType.ORGANIZATION.value,
+                    "entity_id": org_id,
+                    "amount": response_cost,
+                }
             )
         except Exception as e:
             verbose_proxy_logger.info(
@@ -435,7 +388,7 @@ class DBSpendUpdateWriter:
         - Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB)
         """
         await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
-            prisma_client=prisma_client,
+            spend_update_queue=self.spend_update_queue,
         )
 
         # Only commit from redis to db if this pod is the leader
@@ -447,7 +400,7 @@ class DBSpendUpdateWriter:
                     await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer()
                 )
                 if db_spend_update_transactions is not None:
-                    await DBSpendUpdateWriter._commit_spend_updates_to_db(
+                    await self._commit_spend_updates_to_db(
                         prisma_client=prisma_client,
                         n_retry_times=n_retry_times,
                         proxy_logging_obj=proxy_logging_obj,
@@ -471,23 +424,26 @@ class DBSpendUpdateWriter:
 
         Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS.
         """
-        db_spend_update_transactions = DBSpendUpdateTransactions(
-            user_list_transactions=prisma_client.user_list_transactions,
-            end_user_list_transactions=prisma_client.end_user_list_transactions,
-            key_list_transactions=prisma_client.key_list_transactions,
-            team_list_transactions=prisma_client.team_list_transactions,
-            team_member_list_transactions=prisma_client.team_member_list_transactions,
-            org_list_transactions=prisma_client.org_list_transactions,
+        aggregated_updates = (
+            await self.spend_update_queue.flush_and_get_all_aggregated_updates_by_entity_type()
         )
-        await DBSpendUpdateWriter._commit_spend_updates_to_db(
+        db_spend_update_transactions = DBSpendUpdateTransactions(
+            user_list_transactions=aggregated_updates.get("user", {}),
+            end_user_list_transactions=aggregated_updates.get("end_user", {}),
+            key_list_transactions=aggregated_updates.get("key", {}),
+            team_list_transactions=aggregated_updates.get("team", {}),
+            team_member_list_transactions=aggregated_updates.get("team_member", {}),
+            org_list_transactions=aggregated_updates.get("organization", {}),
+        )
+        await self._commit_spend_updates_to_db(
             prisma_client=prisma_client,
             n_retry_times=n_retry_times,
             proxy_logging_obj=proxy_logging_obj,
             db_spend_update_transactions=db_spend_update_transactions,
         )
 
-    @staticmethod
     async def _commit_spend_updates_to_db(  # noqa: PLR0915
+        self,
         prisma_client: PrismaClient,
         n_retry_times: int,
         proxy_logging_obj: ProxyLogging,
@@ -526,9 +482,6 @@ class DBSpendUpdateWriter:
                                     where={"user_id": user_id},
                                     data={"spend": {"increment": response_cost}},
                                 )
-                    prisma_client.user_list_transactions = (
-                        {}
-                    )  # Clear the remaining transactions after processing all batches in the loop.
                     break
                 except DB_CONNECTION_ERROR_TYPES as e:
                     if (
@@ -583,9 +536,6 @@ class DBSpendUpdateWriter:
                                     where={"token": token},
                                     data={"spend": {"increment": response_cost}},
                                 )
-                    prisma_client.key_list_transactions = (
-                        {}
-                    )  # Clear the remaining transactions after processing all batches in the loop.
                     break
                 except DB_CONNECTION_ERROR_TYPES as e:
                     if (
@@ -632,9 +582,6 @@ class DBSpendUpdateWriter:
                                     where={"team_id": team_id},
                                     data={"spend": {"increment": response_cost}},
                                 )
-                    prisma_client.team_list_transactions = (
-                        {}
-                    )  # Clear the remaining transactions after processing all batches in the loop.
                     break
                 except DB_CONNECTION_ERROR_TYPES as e:
                     if (
@@ -684,9 +631,6 @@ class DBSpendUpdateWriter:
                                     where={"team_id": team_id, "user_id": user_id},
                                     data={"spend": {"increment": response_cost}},
                                 )
-                    prisma_client.team_member_list_transactions = (
-                        {}
-                    )  # Clear the remaining transactions after processing all batches in the loop.
                     break
                 except DB_CONNECTION_ERROR_TYPES as e:
                     if (
@@ -725,9 +669,6 @@ class DBSpendUpdateWriter:
                                     where={"organization_id": org_id},
                                     data={"spend": {"increment": response_cost}},
                                 )
-                    prisma_client.org_list_transactions = (
-                        {}
-                    )  # Clear the remaining transactions after processing all batches in the loop.
                     break
                 except DB_CONNECTION_ERROR_TYPES as e:
                     if (
diff --git a/litellm/proxy/db/redis_update_buffer.py b/litellm/proxy/db/redis_update_buffer.py
index 7370b0ed08..0dfaa72a16 100644
--- a/litellm/proxy/db/redis_update_buffer.py
+++ b/litellm/proxy/db/redis_update_buffer.py
@@ -33,7 +33,6 @@ class RedisUpdateBuffer:
         redis_cache: Optional[RedisCache] = None,
     ):
         self.redis_cache = redis_cache
-        self.spend_update_queue = SpendUpdateQueue()
 
     @staticmethod
     def _should_commit_spend_updates_to_redis() -> bool:
@@ -56,6 +55,7 @@ class RedisUpdateBuffer:
 
     async def store_in_memory_spend_updates_in_redis(
         self,
+        spend_update_queue: SpendUpdateQueue,
     ):
         """
         Stores the in-memory spend updates to Redis
@@ -81,9 +81,9 @@ class RedisUpdateBuffer:
             return
 
         aggregated_updates = (
-            await self.spend_update_queue.flush_and_get_all_aggregated_updates_by_entity_type()
+            await spend_update_queue.flush_and_get_all_aggregated_updates_by_entity_type()
         )
-        verbose_proxy_logger.debug("ALL AGGREGATED UPDATES: ", aggregated_updates)
+        verbose_proxy_logger.debug("ALL AGGREGATED UPDATES: %s", aggregated_updates)
 
         db_spend_update_transactions: DBSpendUpdateTransactions = (
             DBSpendUpdateTransactions(
@@ -92,7 +92,7 @@ class RedisUpdateBuffer:
                 key_list_transactions=aggregated_updates.get("key", {}),
                 team_list_transactions=aggregated_updates.get("team", {}),
                 team_member_list_transactions=aggregated_updates.get("team_member", {}),
-                org_list_transactions=aggregated_updates.get("org", {}),
+                org_list_transactions=aggregated_updates.get("organization", {}),
             )
         )
 
diff --git a/litellm/proxy/db/spend_update_queue.py b/litellm/proxy/db/spend_update_queue.py
index 778361be84..2462e6a547 100644
--- a/litellm/proxy/db/spend_update_queue.py
+++ b/litellm/proxy/db/spend_update_queue.py
@@ -1,6 +1,8 @@
 import asyncio
 from typing import TYPE_CHECKING, Any, Dict, List
 
+from litellm._logging import verbose_proxy_logger
+
 if TYPE_CHECKING:
     from litellm.proxy.utils import PrismaClient
 else:
@@ -21,6 +23,7 @@ class SpendUpdateQueue:
 
     async def add_update(self, update: Dict[str, Any]) -> None:
         """Enqueue an update. Each update might be a dict like {'entity_type': 'user', 'entity_id': '123', 'amount': 1.2}."""
+        verbose_proxy_logger.debug("Adding update to queue: %s", update)
         await self.update_queue.put(update)
 
     async def flush_all_updates_from_in_memory_queue(self) -> List[Dict[str, Any]]:
@@ -35,6 +38,7 @@ class SpendUpdateQueue:
     ) -> Dict[str, Any]:
         """Flush all updates from the queue and return all updates aggregated by entity type."""
         updates = await self.flush_all_updates_from_in_memory_queue()
+        verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
         return self.aggregate_updates_by_entity_type(updates)
 
     def aggregate_updates_by_entity_type(
diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py
index 39c1eeace9..dc0c27eb3e 100644
--- a/litellm/proxy/hooks/proxy_track_cost_callback.py
+++ b/litellm/proxy/hooks/proxy_track_cost_callback.py
@@ -37,6 +37,8 @@ class _ProxyDBLogger(CustomLogger):
         if _ProxyDBLogger._should_track_errors_in_db() is False:
             return
 
+        from litellm.proxy.proxy_server import proxy_logging_obj
+
         _metadata = dict(
             StandardLoggingUserAPIKeyMetadata(
                 user_api_key_hash=user_api_key_dict.api_key,
@@ -66,7 +68,7 @@ class _ProxyDBLogger(CustomLogger):
             request_data.get("proxy_server_request") or {}
         )
         request_data["litellm_params"]["metadata"] = existing_metadata
-        await DBSpendUpdateWriter.update_database(
+        await proxy_logging_obj.db_spend_update_writer.update_database(
             token=user_api_key_dict.api_key,
             response_cost=0.0,
             user_id=user_api_key_dict.user_id,
@@ -136,7 +138,7 @@ class _ProxyDBLogger(CustomLogger):
                     end_user_id=end_user_id,
                 ):
                     ## UPDATE DATABASE
-                    await DBSpendUpdateWriter.update_database(
+                    await proxy_logging_obj.db_spend_update_writer.update_database(
                         token=user_api_key,
                         response_cost=response_cost,
                         user_id=user_id,