# litellm-mirror/tests/litellm/proxy/db/test_pod_lock_manager.py
# 2025-04-01 18:30:48 -07:00
#
# 320 lines
# 11 KiB
# Python

import json
import os
import sys
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi.testclient import TestClient
# Put the directory three levels above the current working directory at the
# front of sys.path. os.path.abspath resolves "../../.." against the process
# cwd (not this file's location), so the entry depends on where pytest is run.
# NOTE(review): presumably this makes the in-repo `litellm` package importable
# for the imports below - confirm against how the test suite is invoked.
sys.path.insert(
0, os.path.abspath("../../..")
) # Three levels up from the cwd, searched before any other sys.path entry
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager
class MockPrismaClient:
    """Minimal Prisma client stand-in used by the tests in this module.

    Exposes a ``db`` attribute (a ``MagicMock``) whose ``litellm_cronjob``
    attribute is an ``AsyncMock``, so calls made on it can be awaited.
    """

    def __init__(self):
        fake_db = MagicMock()
        # AsyncMock so every method reached through the table is awaitable.
        fake_db.litellm_cronjob = AsyncMock()
        self.db = fake_db
@pytest.fixture
def mock_prisma(monkeypatch):
    """Install a ``MockPrismaClient`` as ``litellm.proxy.proxy_server.prisma_client``.

    The replacement is made through ``monkeypatch``, so pytest restores the
    original attribute when the test finishes.

    Returns:
        The ``MockPrismaClient`` instance, so tests can program return values
        and side effects on ``mock_prisma.db.litellm_cronjob``.
    """
    mock_client = MockPrismaClient()
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_client)
    return mock_client
@pytest.fixture
def pod_lock_manager():
    """Provide a new ``PodLockManager`` for the ``"test_job"`` cron job."""
    manager = PodLockManager(cronjob_id="test_job")
    return manager
@pytest.mark.asyncio
async def test_acquire_lock_success(pod_lock_manager, mock_prisma):
    """
    Test that the lock is acquired successfully when no existing lock exists
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # No lock row exists yet for this cron job
    cronjob_table.find_unique.return_value = None

    # The row handed back by create() is owned by this pod
    created_lock = AsyncMock()
    created_lock.status = "ACTIVE"
    created_lock.pod_id = pod_lock_manager.pod_id
    cronjob_table.create.return_value = created_lock

    result = await pod_lock_manager.acquire_lock()

    # Identity check: a merely truthy non-bool return value must not pass
    assert result is True

    # The lookup happened exactly once ...
    cronjob_table.find_unique.assert_called_once()

    # ... and the new row was created once, with this pod as the active owner
    cronjob_table.create.assert_called_once()
    create_data = cronjob_table.create.call_args.kwargs["data"]
    assert create_data["cronjob_id"] == "test_job"
    assert create_data["pod_id"] == pod_lock_manager.pod_id
    assert create_data["status"] == "ACTIVE"
@pytest.mark.asyncio
async def test_acquire_lock_existing_active(pod_lock_manager, mock_prisma):
    """
    Test that the lock is not acquired if there's an active lock by different pod
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Existing lock: ACTIVE, owned by another pod, TTL still in the future
    active_lock = AsyncMock()
    active_lock.status = "ACTIVE"
    active_lock.pod_id = "different_pod_id"
    active_lock.ttl = datetime.now(timezone.utc) + timedelta(seconds=30)
    cronjob_table.find_unique.return_value = active_lock

    result = await pod_lock_manager.acquire_lock()

    # Identity check: a merely falsy non-bool return value must not pass
    assert result is False

    # The row was looked up, but nothing was written
    cronjob_table.find_unique.assert_called_once()
    cronjob_table.update.assert_not_called()
    cronjob_table.create.assert_not_called()
@pytest.mark.asyncio
async def test_acquire_lock_expired(pod_lock_manager, mock_prisma):
    """
    Test that the lock can be acquired if existing lock is expired
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Existing lock: still marked ACTIVE by another pod, but its TTL has passed
    expired_lock = AsyncMock()
    expired_lock.status = "ACTIVE"
    expired_lock.pod_id = "different_pod_id"
    expired_lock.ttl = datetime.now(timezone.utc) - timedelta(seconds=30)
    cronjob_table.find_unique.return_value = expired_lock

    # The update hands back a row owned by this pod
    taken_over_lock = AsyncMock()
    taken_over_lock.pod_id = pod_lock_manager.pod_id
    cronjob_table.update.return_value = taken_over_lock

    result = await pod_lock_manager.acquire_lock()

    # Identity check: a merely truthy non-bool return value must not pass
    assert result is True

    # The row was looked up once and then updated once
    cronjob_table.find_unique.assert_called_once()
    cronjob_table.update.assert_called_once()
@pytest.mark.asyncio
async def test_renew_lock(pod_lock_manager, mock_prisma):
    """
    Test that renewing the lock issues one DB update scoped to this cron job
    and pod, and that the update payload carries ttl and last_updated
    """
    cronjob_table = mock_prisma.db.litellm_cronjob
    cronjob_table.update.return_value = AsyncMock()

    await pod_lock_manager.renew_lock()

    # Exactly one update, addressed to this pod's row for this cron job
    cronjob_table.update.assert_called_once()
    update_kwargs = cronjob_table.update.call_args.kwargs
    where_clause = update_kwargs["where"]
    assert where_clause["cronjob_id"] == "test_job"
    assert where_clause["pod_id"] == pod_lock_manager.pod_id

    # Both timestamp fields are part of the payload
    payload = update_kwargs["data"]
    assert "ttl" in payload
    assert "last_updated" in payload
@pytest.mark.asyncio
async def test_release_lock(pod_lock_manager, mock_prisma):
    """
    Test that releasing the lock issues one DB update scoped to this cron job
    and pod, and that the update sets the status to INACTIVE
    """
    cronjob_table = mock_prisma.db.litellm_cronjob
    cronjob_table.update.return_value = AsyncMock()

    await pod_lock_manager.release_lock()

    # Exactly one update, addressed to this pod's row for this cron job
    cronjob_table.update.assert_called_once()
    update_kwargs = cronjob_table.update.call_args.kwargs
    where_clause = update_kwargs["where"]
    assert where_clause["cronjob_id"] == "test_job"
    assert where_clause["pod_id"] == pod_lock_manager.pod_id

    # Releasing marks the row INACTIVE
    assert update_kwargs["data"]["status"] == "INACTIVE"
@pytest.mark.asyncio
async def test_prisma_client_none(pod_lock_manager, monkeypatch):
    """
    Test that acquire, renew and release all return False when no Prisma
    client is configured
    """
    # Replace the proxy server's prisma_client with None for this test
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None)

    # Identity checks: only a real False passes, not another falsy value
    assert await pod_lock_manager.acquire_lock() is False
    assert await pod_lock_manager.renew_lock() is False
    assert await pod_lock_manager.release_lock() is False
@pytest.mark.asyncio
async def test_database_error_handling(pod_lock_manager, mock_prisma):
    """
    Test that database errors do not escape the lock manager: acquire_lock
    returns False, and renew_lock / release_lock do not raise
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Make every table operation fail. The other tests in this module show
    # acquire_lock going through find_unique and then create/update, so
    # failing only upsert/update would leave the outcome dependent on whatever
    # an unconfigured find_unique mock happens to return.
    db_error = Exception("Database error")
    for operation in ("find_unique", "create", "update", "upsert"):
        getattr(cronjob_table, operation).side_effect = db_error

    # Test error handling in all methods
    assert await pod_lock_manager.acquire_lock() is False
    await pod_lock_manager.renew_lock()  # Should not raise exception
    await pod_lock_manager.release_lock()  # Should not raise exception
@pytest.mark.asyncio
async def test_acquire_lock_inactive_status(pod_lock_manager, mock_prisma):
    """
    Test that the lock can be acquired if existing lock is INACTIVE
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Existing lock: INACTIVE, owned by another pod, TTL still in the future
    inactive_lock = AsyncMock()
    inactive_lock.status = "INACTIVE"
    inactive_lock.pod_id = "different_pod_id"
    inactive_lock.ttl = datetime.now(timezone.utc) + timedelta(seconds=30)
    cronjob_table.find_unique.return_value = inactive_lock

    # The update hands back a row owned by this pod
    claimed_lock = AsyncMock()
    claimed_lock.pod_id = pod_lock_manager.pod_id
    cronjob_table.update.return_value = claimed_lock

    result = await pod_lock_manager.acquire_lock()

    # Identity check: a merely truthy non-bool return value must not pass
    assert result is True
    cronjob_table.update.assert_called_once()
@pytest.mark.asyncio
async def test_acquire_lock_same_pod(pod_lock_manager, mock_prisma):
    """
    Test that the lock returns True if the same pod already holds the lock
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Existing lock: ACTIVE, unexpired, and already owned by this pod
    own_lock = AsyncMock()
    own_lock.status = "ACTIVE"
    own_lock.pod_id = pod_lock_manager.pod_id
    own_lock.ttl = datetime.now(timezone.utc) + timedelta(seconds=30)
    cronjob_table.find_unique.return_value = own_lock

    result = await pod_lock_manager.acquire_lock()

    # Identity check: a merely truthy non-bool return value must not pass
    assert result is True

    # Already holding the lock means nothing is written
    cronjob_table.update.assert_not_called()
    cronjob_table.create.assert_not_called()
@pytest.mark.asyncio
async def test_acquire_lock_race_condition(pod_lock_manager, mock_prisma):
    """
    Test handling of potential race conditions during lock acquisition
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # The lookup sees no row ...
    cronjob_table.find_unique.return_value = None
    # ... but create then fails as if another pod inserted the row first
    cronjob_table.create.side_effect = Exception("Unique constraint violation")

    result = await pod_lock_manager.acquire_lock()

    # Identity check: a merely falsy non-bool return value must not pass
    assert result is False
    # The failure must come from the create attempt, not an earlier step
    cronjob_table.create.assert_called_once()
@pytest.mark.asyncio
async def test_ttl_calculation(pod_lock_manager, mock_prisma):
    """
    Test that the TTL written when acquiring a new lock is roughly
    DEFAULT_CRON_JOB_LOCK_TTL_SECONDS from now
    """
    cronjob_table = mock_prisma.db.litellm_cronjob
    cronjob_table.find_unique.return_value = None
    cronjob_table.create.return_value = AsyncMock()

    await pod_lock_manager.acquire_lock()

    # TTL that was passed to create()
    written_ttl = cronjob_table.create.call_args.kwargs["data"]["ttl"]

    # Reference value is taken after the call, hence the tolerance below
    lock_lifetime = timedelta(seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS)
    expected_ttl = datetime.now(timezone.utc) + lock_lifetime

    drift_seconds = abs((written_ttl - expected_ttl).total_seconds())
    assert drift_seconds < 1  # Allow 1 second difference
@pytest.mark.asyncio
async def test_concurrent_lock_acquisition_simulation(mock_prisma):
    """
    Simulate multiple pods trying to acquire the lock simultaneously
    """
    cronjob_table = mock_prisma.db.litellm_cronjob
    pod1 = PodLockManager(cronjob_id="test_job")
    pod2 = PodLockManager(cronjob_id="test_job")
    pod3 = PodLockManager(cronjob_id="test_job")

    # Phase 1: no row exists, and create() hands back a row owned by pod1
    cronjob_table.find_unique.return_value = None
    created_lock = AsyncMock()
    created_lock.pod_id = pod1.pod_id
    created_lock.status = "ACTIVE"
    cronjob_table.create.return_value = created_lock

    # First pod should get the lock
    result1 = await pod1.acquire_lock()
    assert result1 is True

    # Phase 2: the lookup now returns pod1's ACTIVE, unexpired lock
    held_lock = AsyncMock()
    held_lock.status = "ACTIVE"
    held_lock.pod_id = pod1.pod_id
    held_lock.ttl = datetime.now(timezone.utc) + timedelta(seconds=30)
    cronjob_table.find_unique.return_value = held_lock

    # Other pods should fail to acquire
    result2 = await pod2.acquire_lock()
    result3 = await pod3.acquire_lock()
    assert result2 is False
    assert result3 is False

    # Only pod1 wrote anything: one create, and no update from the losers
    cronjob_table.create.assert_called_once()
    cronjob_table.update.assert_not_called()
@pytest.mark.asyncio
async def test_lock_takeover_race_condition(mock_prisma):
    """
    Test scenario where multiple pods try to take over an expired lock
    """
    cronjob_table = mock_prisma.db.litellm_cronjob
    pod1 = PodLockManager(cronjob_id="test_job")
    pod2 = PodLockManager(cronjob_id="test_job")

    # Both pods see the same expired lock left behind by "old_pod"
    expired_lock = AsyncMock()
    expired_lock.status = "ACTIVE"
    expired_lock.pod_id = "old_pod"
    expired_lock.ttl = datetime.now(timezone.utc) - timedelta(seconds=30)
    cronjob_table.find_unique.return_value = expired_lock

    # pod1's update succeeds and hands back a row owned by pod1
    pod1_lock = AsyncMock()
    pod1_lock.pod_id = pod1.pod_id
    cronjob_table.update.return_value = pod1_lock

    # First pod should successfully take over
    result1 = await pod1.acquire_lock()
    assert result1 is True

    # pod2's update raises, as if the row changed under it
    cronjob_table.update.side_effect = Exception(
        "Row was updated by another transaction"
    )

    # Second pod should fail to take over
    result2 = await pod2.acquire_lock()
    assert result2 is False