diff --git a/tests/proxy_unit_tests/test_update_spend.py b/tests/proxy_unit_tests/test_update_spend.py index 36965cafa7..641768a7d2 100644 --- a/tests/proxy_unit_tests/test_update_spend.py +++ b/tests/proxy_unit_tests/test_update_spend.py @@ -20,7 +20,12 @@ from litellm.proxy.utils import update_spend, DB_CONNECTION_ERROR_TYPES class MockPrismaClient: def __init__(self): - self.db = MagicMock() + # Create AsyncMock for db operations + self.db = AsyncMock() + self.db.litellm_spendlogs = AsyncMock() + self.db.litellm_spendlogs.create_many = AsyncMock() + + # Initialize transaction lists self.spend_log_transactions = [] self.user_list_transactons = {} self.end_user_list_transactons = {} @@ -33,6 +38,20 @@ class MockPrismaClient: def jsonify_object(self, obj): return obj + def add_spend_log_transaction_to_daily_user_transaction(self, payload): + # Mock implementation + pass + + +def create_mock_proxy_logging(): + print("creating mock proxy logging") + proxy_logging_obj = MagicMock() + proxy_logging_obj.failure_handler = AsyncMock() + proxy_logging_obj.db_spend_update_writer = AsyncMock() + proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler = AsyncMock() + print("returning proxy logging obj") + return proxy_logging_obj + @pytest.mark.asyncio @pytest.mark.parametrize( @@ -47,8 +66,11 @@ async def test_update_spend_logs_connection_errors(error_type): """Test retry mechanism for different connection error types""" # Setup prisma_client = MockPrismaClient() - proxy_logging_obj = MagicMock() - proxy_logging_obj.failure_handler = AsyncMock() + proxy_logging_obj = create_mock_proxy_logging() + + # Create AsyncMock for db_spend_update_writer + proxy_logging_obj.db_spend_update_writer = AsyncMock() + proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler = AsyncMock() # Add test spend logs prisma_client.spend_log_transactions = [ @@ -90,8 +112,7 @@ async def test_update_spend_logs_max_retries_exceeded(error_type): """Test that each connection error type properly fails after max retries""" # Setup prisma_client = MockPrismaClient() - proxy_logging_obj = MagicMock() - proxy_logging_obj.failure_handler = AsyncMock() + proxy_logging_obj = create_mock_proxy_logging() # Add test spend logs prisma_client.spend_log_transactions = [ @@ -123,8 +144,7 @@ async def test_update_spend_logs_non_connection_error(): """Test handling of non-connection related errors""" # Setup prisma_client = MockPrismaClient() - proxy_logging_obj = MagicMock() - proxy_logging_obj.failure_handler = AsyncMock() + proxy_logging_obj = create_mock_proxy_logging() # Add test spend logs prisma_client.spend_log_transactions = [ @@ -155,8 +175,7 @@ async def test_update_spend_logs_exponential_backoff(): """Test that exponential backoff is working correctly""" # Setup prisma_client = MockPrismaClient() - proxy_logging_obj = MagicMock() - proxy_logging_obj.failure_handler = AsyncMock() + proxy_logging_obj = create_mock_proxy_logging() # Add test spend logs prisma_client.spend_log_transactions = [{"id": "1", "spend": 10}] @@ -198,8 +217,7 @@ async def test_update_spend_logs_multiple_batches_success(): """ # Setup prisma_client = MockPrismaClient() - proxy_logging_obj = MagicMock() - proxy_logging_obj.failure_handler = AsyncMock() + proxy_logging_obj = create_mock_proxy_logging() # Create 150 test spend logs (1.5x BATCH_SIZE) prisma_client.spend_log_transactions = [ @@ -245,8 +263,7 @@ async def test_update_spend_logs_multiple_batches_with_failure(): """ # Setup prisma_client = MockPrismaClient() - proxy_logging_obj = MagicMock() - proxy_logging_obj.failure_handler = AsyncMock() + proxy_logging_obj = create_mock_proxy_logging() # Create 400 test spend logs (4x BATCH_SIZE) prisma_client.spend_log_transactions = [