diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py
index eee598f79a..9047bd4c38 100644
--- a/litellm/proxy/db/db_spend_update_writer.py
+++ b/litellm/proxy/db/db_spend_update_writer.py
@@ -22,7 +22,7 @@ from litellm.proxy._types import (
     LiteLLM_UserTable,
     SpendLogsPayload,
 )
-from litellm.proxy.db.pod_leader_manager import PodLockManager
+from litellm.proxy.db.pod_lock_manager import PodLockManager
 from litellm.proxy.db.redis_update_buffer import (
     DBSpendUpdateTransactions,
     RedisUpdateBuffer,
@@ -49,7 +49,7 @@ class DBSpendUpdateWriter:
     ):
         self.redis_cache = redis_cache
         self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
-        self.pod_leader_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
+        self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
 
     @staticmethod
     async def update_database(
@@ -405,7 +405,7 @@ class DBSpendUpdateWriter:
             )
 
             # Only commit from redis to db if this pod is the leader
-            if await self.pod_leader_manager.acquire_lock():
+            if await self.pod_lock_manager.acquire_lock():
                 verbose_proxy_logger.debug("acquired lock for spend updates")
 
                 try:
@@ -422,7 +422,7 @@ class DBSpendUpdateWriter:
                 except Exception as e:
                     verbose_proxy_logger.error(f"Error committing spend updates: {e}")
                 finally:
-                    await self.pod_leader_manager.release_lock()
+                    await self.pod_lock_manager.release_lock()
         else:
             db_spend_update_transactions = DBSpendUpdateTransactions(
                 user_list_transactions=prisma_client.user_list_transactions,
diff --git a/litellm/proxy/db/pod_leader_manager.py b/litellm/proxy/db/pod_lock_manager.py
similarity index 100%
rename from litellm/proxy/db/pod_leader_manager.py
rename to litellm/proxy/db/pod_lock_manager.py
diff --git a/tests/litellm/proxy/db/test_pod_lock_manager.py b/tests/litellm/proxy/db/test_pod_lock_manager.py
new file mode 100644
index 0000000000..8894ea6716
--- /dev/null
+++ b/tests/litellm/proxy/db/test_pod_lock_manager.py
@@ -0,0 +1,127 @@
+import json
+import os
+import sys
+from datetime import datetime, timedelta, timezone
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
+from litellm.proxy.db.pod_lock_manager import PodLockManager
+
+
+# Mock Prisma client class
+class MockPrismaClient:
+    def __init__(self):
+        self.db = MagicMock()
+        self.db.litellm_cronjob = AsyncMock()
+
+
+@pytest.fixture
+def mock_prisma(monkeypatch):
+    mock_client = MockPrismaClient()
+
+    # Mock the prisma_client import in proxy_server
+    def mock_get_prisma():
+        return mock_client
+
+    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_client)
+    return mock_client
+
+
+@pytest.fixture
+def pod_lock_manager():
+    return PodLockManager(cronjob_id="test_job")
+
+
+@pytest.mark.asyncio
+async def test_acquire_lock_success(pod_lock_manager, mock_prisma):
+    # Mock successful lock acquisition
+    mock_response = AsyncMock()
+    mock_response.status = "ACTIVE"
+    mock_response.pod_id = pod_lock_manager.pod_id
+    mock_prisma.db.litellm_cronjob.upsert.return_value = mock_response
+
+    result = await pod_lock_manager.acquire_lock()
+    assert result == True
+
+    # Verify upsert was called with correct parameters
+    mock_prisma.db.litellm_cronjob.upsert.assert_called_once()
+    call_args = mock_prisma.db.litellm_cronjob.upsert.call_args[1]
+    assert call_args["where"]["cronjob_id"] == "test_job"
+    assert "create" in call_args["data"]
+    assert "update" in call_args["data"]
+
+
+@pytest.mark.asyncio
+async def test_acquire_lock_failure(pod_lock_manager, mock_prisma):
+    """
+    Test that the lock is not acquired if the lock is held by a different pod
+    """
+    # Mock failed lock acquisition (different pod holds the lock)
+    mock_response = AsyncMock()
+    mock_response.status = "ACTIVE"
+    mock_response.pod_id = "different_pod_id"
+    mock_prisma.db.litellm_cronjob.upsert.return_value = mock_response
+
+    result = await pod_lock_manager.acquire_lock()
+    assert result == False
+
+
+@pytest.mark.asyncio
+async def test_renew_lock(pod_lock_manager, mock_prisma):
+    # Mock successful lock renewal
+    mock_prisma.db.litellm_cronjob.update.return_value = AsyncMock()
+
+    await pod_lock_manager.renew_lock()
+
+    # Verify update was called with correct parameters
+    mock_prisma.db.litellm_cronjob.update.assert_called_once()
+    call_args = mock_prisma.db.litellm_cronjob.update.call_args[1]
+    assert call_args["where"]["cronjob_id"] == "test_job"
+    assert call_args["where"]["pod_id"] == pod_lock_manager.pod_id
+    assert "ttl" in call_args["data"]
+    assert "last_updated" in call_args["data"]
+
+
+@pytest.mark.asyncio
+async def test_release_lock(pod_lock_manager, mock_prisma):
+    # Mock successful lock release
+    mock_prisma.db.litellm_cronjob.update.return_value = AsyncMock()
+
+    await pod_lock_manager.release_lock()
+
+    # Verify update was called with correct parameters
+    mock_prisma.db.litellm_cronjob.update.assert_called_once()
+    call_args = mock_prisma.db.litellm_cronjob.update.call_args[1]
+    assert call_args["where"]["cronjob_id"] == "test_job"
+    assert call_args["where"]["pod_id"] == pod_lock_manager.pod_id
+    assert call_args["data"]["status"] == "INACTIVE"
+
+
+@pytest.mark.asyncio
+async def test_prisma_client_none(pod_lock_manager, monkeypatch):
+    # Mock prisma_client as None
+    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None)
+
+    # Test all methods with None client
+    assert await pod_lock_manager.acquire_lock() == False
+    assert await pod_lock_manager.renew_lock() == False
+    assert await pod_lock_manager.release_lock() == False
+
+
+@pytest.mark.asyncio
+async def test_database_error_handling(pod_lock_manager, mock_prisma):
+    # Mock database errors
+    mock_prisma.db.litellm_cronjob.upsert.side_effect = Exception("Database error")
+    mock_prisma.db.litellm_cronjob.update.side_effect = Exception("Database error")
+
+    # Test error handling in all methods
+    assert await pod_lock_manager.acquire_lock() == False
+    await pod_lock_manager.renew_lock()  # Should not raise exception
+    await pod_lock_manager.release_lock()  # Should not raise exception