diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py
index 00f60a76f2..21717b8cbd 100644
--- a/litellm/proxy/db/db_spend_update_writer.py
+++ b/litellm/proxy/db/db_spend_update_writer.py
@@ -15,12 +15,14 @@ from typing import TYPE_CHECKING, Any, Optional, Union
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache, RedisCache, RedisClusterCache
+from litellm.constants import DB_SPEND_UPDATE_JOB_NAME
 from litellm.proxy._types import (
     DB_CONNECTION_ERROR_TYPES,
     Litellm_EntityType,
     LiteLLM_UserTable,
     SpendLogsPayload,
 )
+from litellm.proxy.db.pod_leader_manager import PodLockManager
 from litellm.proxy.db.redis_update_buffer import RedisUpdateBuffer
 
 if TYPE_CHECKING:
@@ -41,8 +43,14 @@ class DBSpendUpdateWriter:
     def __init__(
         self, redis_cache: Optional[Union[RedisCache, RedisClusterCache]] = None
     ):
+        from litellm.proxy.proxy_server import prisma_client
+
         self.redis_cache = redis_cache
         self.redis_update_buffer = RedisUpdateBuffer(redis_cache=redis_cache)
+        self.pod_leader_manager = PodLockManager(
+            cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
+            prisma_client=prisma_client,
+        )
 
     @staticmethod
     async def update_database(
@@ -397,33 +405,20 @@ class DBSpendUpdateWriter:
             prisma_client=prisma_client,
         )
 
-        if DBSpendUpdateWriter._should_commit_spend_updates_to_db():
+        if await self.pod_leader_manager.acquire_lock():
+            await DBSpendUpdateWriter._commit_spend_updates_to_db(
+                prisma_client=prisma_client,
+                n_retry_times=n_retry_times,
+                proxy_logging_obj=proxy_logging_obj,
+            )
+            await self.pod_leader_manager.release_lock()
+        else:
             await DBSpendUpdateWriter._commit_spend_updates_to_db(
                 prisma_client=prisma_client,
                 n_retry_times=n_retry_times,
                 proxy_logging_obj=proxy_logging_obj,
             )
-            pass
-
-    @staticmethod
-    async def _commit_spend_updates_to_redis(
-        prisma_client: PrismaClient,
-    ):
-        """
-        Commits all the spend updates to Redis for each entity type
-
-        once committed, the transactions are cleared from the in-memory variables
-        """
-        pass
-
-    @staticmethod
-    def _should_commit_spend_updates_to_db() -> bool:
-        """
-        Checks if the Pod should commit spend updates to the Database
-        """
-        return False
-
 
     @staticmethod
     async def _commit_spend_updates_to_db(  # noqa: PLR0915
         prisma_client: PrismaClient,
diff --git a/litellm/proxy/db/pod_leader_manager.py b/litellm/proxy/db/pod_leader_manager.py
new file mode 100644
index 0000000000..73ecd68b1b
--- /dev/null
+++ b/litellm/proxy/db/pod_leader_manager.py
@@ -0,0 +1,111 @@
+import uuid
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING, Any, Optional
+
+from litellm._logging import verbose_proxy_logger
+from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
+
+if TYPE_CHECKING:
+    from litellm.proxy.utils import PrismaClient, ProxyLogging
+else:
+    PrismaClient = Any
+    ProxyLogging = Any
+
+
+class PodLockManager:
+    """
+    Manager for acquiring and releasing locks for cron jobs.
+
+    Ensures that only one pod can run a cron job at a time.
+    """
+
+    def __init__(self, prisma_client: Optional[PrismaClient], cronjob_id: str):
+        self.pod_id = str(uuid.uuid4())
+        self.prisma = prisma_client
+        self.cronjob_id = cronjob_id
+
+    async def acquire_lock(self) -> bool:
+        """
+        Attempt to acquire the lock for a specific cron job.
+        """
+        if not self.prisma:
+            return False
+        try:
+            current_time = datetime.now(timezone.utc)
+            # Lease expiry time
+            ttl_expiry = current_time + timedelta(
+                seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
+            )
+
+            # Attempt to acquire the lock by upserting the record in the `cronjob_locks` table
+            cronjob_lock = await self.prisma.db.cronJob.upsert(
+                where={"cronjob_id": self.cronjob_id},
+                create={
+                    "cronjob_id": self.cronjob_id,
+                    "pod_id": self.pod_id,
+                    "status": "ACTIVE",
+                    "last_updated": current_time,
+                    "ttl": ttl_expiry,
+                },
+                update={
+                    "status": "ACTIVE",
+                    "last_updated": current_time,
+                    "ttl": ttl_expiry,
+                },
+            )
+
+            if cronjob_lock.status == "ACTIVE" and cronjob_lock.pod_id == self.pod_id:
+                verbose_proxy_logger.debug(
+                    f"Pod {self.pod_id} has acquired the lock for {self.cronjob_id}."
+                )
+                return True  # Lock successfully acquired
+            return False
+        except Exception as e:
+            verbose_proxy_logger.error(
+                f"Error acquiring the lock for {self.cronjob_id}: {e}"
+            )
+            return False
+
+    async def renew_lock(self):
+        """
+        Renew the lock (update the TTL) for the pod holding the lock.
+        """
+        if not self.prisma:
+            return False
+        try:
+            current_time = datetime.now(timezone.utc)
+            # Extend the TTL for another DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
+            ttl_expiry = current_time + timedelta(
+                seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
+            )
+
+            await self.prisma.db.cronJob.update(
+                where={"cronjob_id": self.cronjob_id, "pod_id": self.pod_id},
+                data={"ttl": ttl_expiry, "last_updated": current_time},
+            )
+            verbose_proxy_logger.info(
+                f"Renewed the lock for Pod {self.pod_id} for {self.cronjob_id}"
+            )
+        except Exception as e:
+            verbose_proxy_logger.error(
+                f"Error renewing the lock for {self.cronjob_id}: {e}"
+            )
+
+    async def release_lock(self):
+        """
+        Release the lock and mark the pod as inactive.
+        """
+        if not self.prisma:
+            return False
+        try:
+            await self.prisma.db.cronJob.update(
+                where={"cronjob_id": self.cronjob_id, "pod_id": self.pod_id},
+                data={"status": "INACTIVE"},
+            )
+            verbose_proxy_logger.info(
+                f"Pod {self.pod_id} has released the lock for {self.cronjob_id}."
+            )
+        except Exception as e:
+            verbose_proxy_logger.error(
+                f"Error releasing the lock for {self.cronjob_id}: {e}"
+            )
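Note: the new PodLockManager relies on a CronJob model (accessed as `prisma.db.cronJob`) with `cronjob_id`, `pod_id`, `status`, `last_updated`, and `ttl` fields; the corresponding Prisma schema change is not shown in this diff. Below is a minimal usage sketch of the acquire/commit/release flow that `db_update_spend` now follows, assuming the imports added above; `commit_spend_if_leader` and `commit_spend_fn` are hypothetical names used only for illustration.

```python
# Sketch only (not part of the PR): one pod wins the upsert on the CronJob row,
# runs the commit, then releases the lock. `commit_spend_fn` is a hypothetical
# stand-in for DBSpendUpdateWriter._commit_spend_updates_to_db; `prisma_client`
# is assumed to be the proxy's existing PrismaClient instance.
from litellm.constants import DB_SPEND_UPDATE_JOB_NAME
from litellm.proxy.db.pod_leader_manager import PodLockManager


async def commit_spend_if_leader(prisma_client, commit_spend_fn) -> None:
    lock_manager = PodLockManager(
        prisma_client=prisma_client,
        cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
    )
    # acquire_lock() returns True only when the upsert leaves this pod's
    # pod_id on the ACTIVE row; every other pod gets False back.
    if await lock_manager.acquire_lock():
        try:
            await commit_spend_fn()
        finally:
            # Mark the row INACTIVE so another pod can take the lock next cycle.
            await lock_manager.release_lock()
```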