mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
320 lines
11 KiB
Python
320 lines
11 KiB
Python
import json
|
|
import os
|
|
import sys
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the directory three levels above the current working directory to the system path
|
|
|
|
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
|
|
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager
|
|
|
|
|
|
# Mock Prisma client class
|
|
class MockPrismaClient:
    """
    Minimal stand-in for the Prisma client used by these tests.

    Exposes a ``db`` attribute (a ``MagicMock``) whose ``litellm_cronjob``
    attribute is an ``AsyncMock``, so table methods accessed on it can be awaited.
    """

    def __init__(self):
        database = MagicMock()
        # AsyncMock child attributes are themselves awaitable mocks
        database.litellm_cronjob = AsyncMock()
        self.db = database
|
|
|
|
|
|
@pytest.fixture
def mock_prisma(monkeypatch):
    """
    Create a MockPrismaClient and install it as
    ``litellm.proxy.proxy_server.prisma_client`` for the duration of a test.

    Returns the mock client so tests can configure return values / side
    effects on ``mock_client.db.litellm_cronjob`` and inspect recorded calls.
    monkeypatch restores the original attribute when the test finishes.
    """
    mock_client = MockPrismaClient()

    # Replace the prisma_client attribute on proxy_server with the mock
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_client)
    return mock_client
|
|
|
|
|
|
@pytest.fixture
def pod_lock_manager():
    """Build a fresh PodLockManager bound to the cronjob id ``"test_job"``."""
    manager = PodLockManager(cronjob_id="test_job")
    return manager
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_acquire_lock_success(pod_lock_manager, mock_prisma):
    """
    When no lock row exists, acquire_lock() should return True and create a
    row carrying this cronjob id, this pod's id and an ACTIVE status.
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # No existing lock row for this cronjob
    cronjob_table.find_unique.return_value = None

    # Row handed back by the mocked create() call
    created_row = AsyncMock(status="ACTIVE", pod_id=pod_lock_manager.pod_id)
    cronjob_table.create.return_value = created_row

    acquired = await pod_lock_manager.acquire_lock()
    assert acquired == True

    # Exactly one lookup and one create must have happened
    cronjob_table.find_unique.assert_called_once()
    cronjob_table.create.assert_called_once()

    # Inspect the keyword arguments passed to create()
    create_data = cronjob_table.create.call_args.kwargs["data"]
    assert create_data["cronjob_id"] == "test_job"
    assert create_data["pod_id"] == pod_lock_manager.pod_id
    assert create_data["status"] == "ACTIVE"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_acquire_lock_existing_active(pod_lock_manager, mock_prisma):
    """
    An unexpired ACTIVE lock owned by a different pod must block acquisition:
    acquire_lock() returns False and neither update() nor create() is called.
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Lock row owned by another pod, expiring 30 seconds from now
    foreign_lock = AsyncMock(
        status="ACTIVE",
        pod_id="different_pod_id",
        ttl=datetime.now(timezone.utc) + timedelta(seconds=30),
    )
    cronjob_table.find_unique.return_value = foreign_lock

    acquired = await pod_lock_manager.acquire_lock()
    assert acquired == False

    # Only the lookup should have touched the table
    cronjob_table.find_unique.assert_called_once()
    cronjob_table.update.assert_not_called()
    cronjob_table.create.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_acquire_lock_expired(pod_lock_manager, mock_prisma):
    """
    An ACTIVE lock whose TTL is already in the past can be taken over:
    acquire_lock() returns True after one find_unique() and one update().
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Lock row owned by another pod that expired 30 seconds ago
    stale_lock = AsyncMock(
        status="ACTIVE",
        pod_id="different_pod_id",
        ttl=datetime.now(timezone.utc) - timedelta(seconds=30),
    )
    cronjob_table.find_unique.return_value = stale_lock

    # update() reports this pod as the new owner
    taken_over_row = AsyncMock(pod_id=pod_lock_manager.pod_id)
    cronjob_table.update.return_value = taken_over_row

    acquired = await pod_lock_manager.acquire_lock()
    assert acquired == True

    # Both the lookup and the takeover update must have run exactly once
    cronjob_table.find_unique.assert_called_once()
    cronjob_table.update.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_renew_lock(pod_lock_manager, mock_prisma):
    """
    renew_lock() should issue a single update() scoped to this cronjob and
    pod, with both ``ttl`` and ``last_updated`` present in the update data.
    """
    cronjob_table = mock_prisma.db.litellm_cronjob
    cronjob_table.update.return_value = AsyncMock()

    await pod_lock_manager.renew_lock()

    cronjob_table.update.assert_called_once()

    # Keyword arguments the manager passed to update()
    update_kwargs = cronjob_table.update.call_args.kwargs
    where_clause = update_kwargs["where"]
    assert where_clause["cronjob_id"] == "test_job"
    assert where_clause["pod_id"] == pod_lock_manager.pod_id

    for required_field in ("ttl", "last_updated"):
        assert required_field in update_kwargs["data"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_release_lock(pod_lock_manager, mock_prisma):
    """
    release_lock() should issue a single update() scoped to this cronjob and
    pod.

    Specifically, the update data must set the status to INACTIVE.
    """
    cronjob_table = mock_prisma.db.litellm_cronjob
    cronjob_table.update.return_value = AsyncMock()

    await pod_lock_manager.release_lock()

    cronjob_table.update.assert_called_once()

    # Keyword arguments the manager passed to update()
    update_kwargs = cronjob_table.update.call_args.kwargs
    where_clause = update_kwargs["where"]
    assert where_clause["cronjob_id"] == "test_job"
    assert where_clause["pod_id"] == pod_lock_manager.pod_id
    assert update_kwargs["data"]["status"] == "INACTIVE"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_prisma_client_none(pod_lock_manager, monkeypatch):
    """
    With ``proxy_server.prisma_client`` set to None, acquire_lock(),
    renew_lock() and release_lock() must each return False.
    """
    # Simulate a proxy with no database client configured
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None)

    acquired = await pod_lock_manager.acquire_lock()
    assert acquired == False

    renewed = await pod_lock_manager.renew_lock()
    assert renewed == False

    released = await pod_lock_manager.release_lock()
    assert released == False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_database_error_handling(pod_lock_manager, mock_prisma):
    """
    With upsert() and update() raising, acquire_lock() must return False and
    renew_lock() / release_lock() must complete without raising.
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # NOTE(review): the other tests in this file drive acquire_lock() through
    # find_unique/create/update; confirm the upsert side effect is still exercised.
    cronjob_table.upsert.side_effect = Exception("Database error")
    cronjob_table.update.side_effect = Exception("Database error")

    acquired = await pod_lock_manager.acquire_lock()
    assert acquired == False

    # Neither call may propagate the database exception
    await pod_lock_manager.renew_lock()
    await pod_lock_manager.release_lock()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_acquire_lock_inactive_status(pod_lock_manager, mock_prisma):
    """
    A lock row marked INACTIVE can be acquired even though its TTL is still
    in the future: acquire_lock() returns True via a single update().
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Released (INACTIVE) lock row left behind by another pod
    released_lock = AsyncMock(
        status="INACTIVE",
        pod_id="different_pod_id",
        ttl=datetime.now(timezone.utc) + timedelta(seconds=30),
    )
    cronjob_table.find_unique.return_value = released_lock

    # update() reports this pod as the new owner
    taken_over_row = AsyncMock(pod_id=pod_lock_manager.pod_id)
    cronjob_table.update.return_value = taken_over_row

    acquired = await pod_lock_manager.acquire_lock()
    assert acquired == True

    cronjob_table.update.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_acquire_lock_same_pod(pod_lock_manager, mock_prisma):
    """
    If this pod already owns an unexpired ACTIVE lock, acquire_lock() returns
    True without calling update() or create().
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Lock row already owned by the pod under test
    own_lock = AsyncMock(
        status="ACTIVE",
        pod_id=pod_lock_manager.pod_id,
        ttl=datetime.now(timezone.utc) + timedelta(seconds=30),
    )
    cronjob_table.find_unique.return_value = own_lock

    acquired = await pod_lock_manager.acquire_lock()
    assert acquired == True

    # Nothing should have been written
    cronjob_table.update.assert_not_called()
    cronjob_table.create.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_acquire_lock_race_condition(pod_lock_manager, mock_prisma):
    """
    Simulate losing a creation race: the lookup finds no lock, but create()
    raises. acquire_lock() must return False.
    """
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Lookup finds nothing ...
    cronjob_table.find_unique.return_value = None

    # ... but the create then fails, as if another pod inserted the row first
    creation_failure = Exception("Unique constraint violation")
    cronjob_table.create.side_effect = creation_failure

    acquired = await pod_lock_manager.acquire_lock()
    assert acquired == False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ttl_calculation(pod_lock_manager, mock_prisma):
    """
    The ``ttl`` written by acquire_lock() when creating a lock should be
    roughly now + DEFAULT_CRON_JOB_LOCK_TTL_SECONDS.
    """
    cronjob_table = mock_prisma.db.litellm_cronjob
    cronjob_table.find_unique.return_value = None
    cronjob_table.create.return_value = AsyncMock()

    await pod_lock_manager.acquire_lock()

    # TTL value the manager passed to create()
    written_ttl = cronjob_table.create.call_args.kwargs["data"]["ttl"]

    # Reference value is computed after the call, hence the tolerance below
    lock_lifetime = timedelta(seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS)
    expected_ttl = datetime.now(timezone.utc) + lock_lifetime

    drift_seconds = abs((written_ttl - expected_ttl).total_seconds())
    assert drift_seconds < 1  # Allow 1 second difference
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_concurrent_lock_acquisition_simulation(mock_prisma):
    """
    Simulate several pods contending for the same cronjob lock.

    The first manager acquires the lock against an empty table; the mocked
    table is then switched to report that lock as ACTIVE, and the remaining
    managers must fail to acquire it.
    """
    # NOTE(review): relies on each PodLockManager having its own pod_id - confirm
    first_pod, second_pod, third_pod = (
        PodLockManager(cronjob_id="test_job") for _ in range(3)
    )
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Phase 1: empty table, create() reports the first pod as owner
    cronjob_table.find_unique.return_value = None
    created_row = AsyncMock()
    created_row.pod_id = first_pod.pod_id
    created_row.status = "ACTIVE"
    cronjob_table.create.return_value = created_row

    first_result = await first_pod.acquire_lock()
    assert first_result == True

    # Phase 2: lookups now return the first pod's live lock
    held_lock = AsyncMock(
        status="ACTIVE",
        pod_id=first_pod.pod_id,
        ttl=datetime.now(timezone.utc) + timedelta(seconds=30),
    )
    cronjob_table.find_unique.return_value = held_lock

    second_result = await second_pod.acquire_lock()
    third_result = await third_pod.acquire_lock()
    assert second_result == False
    assert third_result == False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_lock_takeover_race_condition(mock_prisma):
    """
    Two pods try to take over the same expired lock.

    The first pod's update() succeeds and it gets the lock; update() is then
    made to raise for the second pod, whose acquire_lock() must return False.
    """
    winner = PodLockManager(cronjob_id="test_job")
    loser = PodLockManager(cronjob_id="test_job")
    cronjob_table = mock_prisma.db.litellm_cronjob

    # Both pods see the same lock, which expired 30 seconds ago
    expired_lock = AsyncMock(
        status="ACTIVE",
        pod_id="old_pod",
        ttl=datetime.now(timezone.utc) - timedelta(seconds=30),
    )
    cronjob_table.find_unique.return_value = expired_lock

    # The first takeover update reports the winner as owner
    winner_row = AsyncMock(pod_id=winner.pod_id)
    cronjob_table.update.return_value = winner_row

    winner_result = await winner.acquire_lock()
    assert winner_result == True

    # The second takeover update fails, as if the row changed underneath it
    conflict = Exception("Row was updated by another transaction")
    cronjob_table.update.side_effect = conflict

    loser_result = await loser.acquire_lock()
    assert loser_result == False
|