test fix update spend

This commit is contained in:
Ishaan Jaff 2025-03-31 14:20:47 -07:00
parent 17ad8a0417
commit ce5f55d04e

View file

@ -20,7 +20,12 @@ from litellm.proxy.utils import update_spend, DB_CONNECTION_ERROR_TYPES
class MockPrismaClient:
def __init__(self):
self.db = MagicMock()
# Create AsyncMock for db operations
self.db = AsyncMock()
self.db.litellm_spendlogs = AsyncMock()
self.db.litellm_spendlogs.create_many = AsyncMock()
# Initialize transaction lists
self.spend_log_transactions = []
self.user_list_transactons = {}
self.end_user_list_transactons = {}
@ -33,6 +38,20 @@ class MockPrismaClient:
def jsonify_object(self, obj):
return obj
def add_spend_log_transaction_to_daily_user_transaction(self, payload):
# Mock implementation
pass
def create_mock_proxy_logging():
print("creating mock proxy logging")
proxy_logging_obj = MagicMock()
proxy_logging_obj.failure_handler = AsyncMock()
proxy_logging_obj.db_spend_update_writer = AsyncMock()
proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler = AsyncMock()
print("returning proxy logging obj")
return proxy_logging_obj
@pytest.mark.asyncio
@pytest.mark.parametrize(
@ -47,8 +66,11 @@ async def test_update_spend_logs_connection_errors(error_type):
"""Test retry mechanism for different connection error types"""
# Setup
prisma_client = MockPrismaClient()
proxy_logging_obj = MagicMock()
proxy_logging_obj.failure_handler = AsyncMock()
proxy_logging_obj = create_mock_proxy_logging()
# Create AsyncMock for db_spend_update_writer
proxy_logging_obj.db_spend_update_writer = AsyncMock()
proxy_logging_obj.db_spend_update_writer.db_update_spend_transaction_handler = AsyncMock()
# Add test spend logs
prisma_client.spend_log_transactions = [
@ -90,8 +112,7 @@ async def test_update_spend_logs_max_retries_exceeded(error_type):
"""Test that each connection error type properly fails after max retries"""
# Setup
prisma_client = MockPrismaClient()
proxy_logging_obj = MagicMock()
proxy_logging_obj.failure_handler = AsyncMock()
proxy_logging_obj = create_mock_proxy_logging()
# Add test spend logs
prisma_client.spend_log_transactions = [
@ -123,8 +144,7 @@ async def test_update_spend_logs_non_connection_error():
"""Test handling of non-connection related errors"""
# Setup
prisma_client = MockPrismaClient()
proxy_logging_obj = MagicMock()
proxy_logging_obj.failure_handler = AsyncMock()
proxy_logging_obj = create_mock_proxy_logging()
# Add test spend logs
prisma_client.spend_log_transactions = [
@ -155,8 +175,7 @@ async def test_update_spend_logs_exponential_backoff():
"""Test that exponential backoff is working correctly"""
# Setup
prisma_client = MockPrismaClient()
proxy_logging_obj = MagicMock()
proxy_logging_obj.failure_handler = AsyncMock()
proxy_logging_obj = create_mock_proxy_logging()
# Add test spend logs
prisma_client.spend_log_transactions = [{"id": "1", "spend": 10}]
@ -198,8 +217,7 @@ async def test_update_spend_logs_multiple_batches_success():
"""
# Setup
prisma_client = MockPrismaClient()
proxy_logging_obj = MagicMock()
proxy_logging_obj.failure_handler = AsyncMock()
proxy_logging_obj = create_mock_proxy_logging()
# Create 150 test spend logs (1.5x BATCH_SIZE)
prisma_client.spend_log_transactions = [
@ -245,8 +263,7 @@ async def test_update_spend_logs_multiple_batches_with_failure():
"""
# Setup
prisma_client = MockPrismaClient()
proxy_logging_obj = MagicMock()
proxy_logging_obj.failure_handler = AsyncMock()
proxy_logging_obj = create_mock_proxy_logging()
# Create 400 test spend logs (4x BATCH_SIZE)
prisma_client.spend_log_transactions = [