Merge pull request #9715 from BerriAI/litellm_refactor_pod_lock_manager

[Reliability Fix] - Use Redis for PodLock Manager instead of PG (ensures no deadlocks occur)
Ishaan Jaff 2025-04-02 21:15:02 -07:00 committed by GitHub
commit 5a722ef18f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 323 additions and 422 deletions
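For context on the approach: the change replaces the Postgres (Prisma) `litellm_cronjob` row lock with a Redis `SET NX EX` lock, so a crashed pod's lock simply expires instead of needing a database update to clean it up, and no DB transaction can deadlock on the lock row. A minimal sketch of that pattern using `redis.asyncio` directly (the function and key names here are illustrative, not the PR's API):

import asyncio
import uuid

import redis.asyncio as redis


async def run_cron_once(r: redis.Redis, cronjob_id: str, ttl_seconds: int = 60) -> bool:
    """Acquire a per-cronjob lock via SET NX EX; skip the job if another pod holds it."""
    pod_id = str(uuid.uuid4())
    lock_key = f"cronjob_lock:{cronjob_id}"

    # NX = only set if the key does not exist, EX = auto-expire so a dead pod
    # cannot deadlock the job forever.
    acquired = await r.set(lock_key, pod_id, nx=True, ex=ttl_seconds)
    if not acquired:
        return False  # another pod owns the lock

    try:
        ...  # do the cron work here
        return True
    finally:
        # Only delete the lock if this pod still owns it (best-effort owner check).
        if await r.get(lock_key) == pod_id.encode():
            await r.delete(lock_key)


if __name__ == "__main__":
    asyncio.run(run_cron_once(redis.Redis(), cronjob_id="example_cron_job"))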

View file

@@ -304,12 +304,18 @@ class RedisCache(BaseCache):
         key = self.check_and_fix_namespace(key=key)
         ttl = self.get_ttl(**kwargs)
+        nx = kwargs.get("nx", False)
         print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")

         try:
             if not hasattr(_redis_client, "set"):
                 raise Exception("Redis client cannot set cache. Attribute not found.")
-            await _redis_client.set(name=key, value=json.dumps(value), ex=ttl)
+            result = await _redis_client.set(
+                name=key,
+                value=json.dumps(value),
+                nx=nx,
+                ex=ttl,
+            )
             print_verbose(
                 f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
             )
@@ -326,6 +332,7 @@ class RedisCache(BaseCache):
                     event_metadata={"key": key},
                 )
             )
+            return result
         except Exception as e:
             end_time = time.time()
             _duration = end_time - start_time
@@ -931,7 +938,7 @@ class RedisCache(BaseCache):
         # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
         _redis_client: Any = self.init_async_client()
         # keys is str
-        await _redis_client.delete(key)
+        return await _redis_client.delete(key)

     def delete_cache(self, key):
         self.redis_client.delete(key)
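With `nx` forwarded to Redis, `async_set_cache` now returns the result of `SET` (truthy only when the key was actually created), which is what lets callers use it as a lock primitive. A rough usage sketch based on the signature shown above (the surrounding function is illustrative):

from litellm.caching.redis_cache import RedisCache


async def try_claim(redis_cache: RedisCache, key: str, owner_id: str) -> bool:
    # nx=True -> the key is only written if it does not already exist;
    # ttl bounds how long a crashed owner can keep the claim.
    result = await redis_cache.async_set_cache(key, owner_id, nx=True, ttl=60)
    return bool(result)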

View file

@@ -1,137 +1,129 @@
 import uuid
-from datetime import datetime, timedelta, timezone
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional

 from litellm._logging import verbose_proxy_logger
+from litellm.caching.redis_cache import RedisCache
 from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS

 if TYPE_CHECKING:
-    from litellm.proxy.utils import PrismaClient, ProxyLogging
+    ProxyLogging = Any
 else:
-    PrismaClient = Any
     ProxyLogging = Any


 class PodLockManager:
     """
-    Manager for acquiring and releasing locks for cron jobs.
+    Manager for acquiring and releasing locks for cron jobs using Redis.

     Ensures that only one pod can run a cron job at a time.
     """

-    def __init__(self, cronjob_id: str):
+    def __init__(self, cronjob_id: str, redis_cache: Optional[RedisCache] = None):
         self.pod_id = str(uuid.uuid4())
         self.cronjob_id = cronjob_id
+        self.redis_cache = redis_cache
+        # Define a unique key for this cronjob lock in Redis.
+        self.lock_key = PodLockManager.get_redis_lock_key(cronjob_id)

-    async def acquire_lock(self) -> bool:
+    @staticmethod
+    def get_redis_lock_key(cronjob_id: str) -> str:
+        return f"cronjob_lock:{cronjob_id}"
+
+    async def acquire_lock(self) -> Optional[bool]:
         """
-        Attempt to acquire the lock for a specific cron job using database locking.
+        Attempt to acquire the lock for a specific cron job using Redis.
+        Uses the SET command with NX and EX options to ensure atomicity.
         """
-        from litellm.proxy.proxy_server import prisma_client
-
-        verbose_proxy_logger.debug(
-            "Pod %s acquiring lock for cronjob_id=%s", self.pod_id, self.cronjob_id
-        )
-        if not prisma_client:
-            verbose_proxy_logger.debug("prisma is None, returning False")
-            return False
-        try:
-            current_time = datetime.now(timezone.utc)
-            ttl_expiry = current_time + timedelta(
-                seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
-            )
-
-            # Use Prisma's findUnique with FOR UPDATE lock to prevent race conditions
-            lock_record = await prisma_client.db.litellm_cronjob.find_unique(
-                where={"cronjob_id": self.cronjob_id},
-            )
-            if lock_record:
-                # If record exists, only update if it's inactive or expired
-                if lock_record.status == "ACTIVE" and lock_record.ttl > current_time:
-                    return lock_record.pod_id == self.pod_id
-
-                # Update existing record
-                updated_lock = await prisma_client.db.litellm_cronjob.update(
-                    where={"cronjob_id": self.cronjob_id},
-                    data={
-                        "pod_id": self.pod_id,
-                        "status": "ACTIVE",
-                        "last_updated": current_time,
-                        "ttl": ttl_expiry,
-                    },
-                )
-            else:
-                # Create new record if none exists
-                updated_lock = await prisma_client.db.litellm_cronjob.create(
-                    data={
-                        "cronjob_id": self.cronjob_id,
-                        "pod_id": self.pod_id,
-                        "status": "ACTIVE",
-                        "last_updated": current_time,
-                        "ttl": ttl_expiry,
-                    }
-                )
-            return updated_lock.pod_id == self.pod_id
-        except Exception as e:
-            verbose_proxy_logger.error(
-                f"Error acquiring the lock for {self.cronjob_id}: {e}"
-            )
-            return False
-
-    async def renew_lock(self):
-        """
-        Renew the lock (update the TTL) for the pod holding the lock.
-        """
-        from litellm.proxy.proxy_server import prisma_client
-
-        if not prisma_client:
-            return False
+        if self.redis_cache is None:
+            verbose_proxy_logger.debug("redis_cache is None, skipping acquire_lock")
+            return None
         try:
             verbose_proxy_logger.debug(
-                "renewing lock for cronjob_id=%s", self.cronjob_id
+                "Pod %s attempting to acquire Redis lock for cronjob_id=%s",
+                self.pod_id,
+                self.cronjob_id,
             )
-            current_time = datetime.now(timezone.utc)
-            # Extend the TTL for another DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
-            ttl_expiry = current_time + timedelta(
-                seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
-            )
-
-            await prisma_client.db.litellm_cronjob.update(
-                where={"cronjob_id": self.cronjob_id, "pod_id": self.pod_id},
-                data={"ttl": ttl_expiry, "last_updated": current_time},
-            )
-            verbose_proxy_logger.info(
-                f"Renewed the lock for Pod {self.pod_id} for {self.cronjob_id}"
+            # Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX)
+            # and with an expiration (EX) to avoid deadlocks.
+            acquired = await self.redis_cache.async_set_cache(
+                self.lock_key,
+                self.pod_id,
+                nx=True,
+                ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
             )
+            if acquired:
+                verbose_proxy_logger.info(
+                    "Pod %s successfully acquired Redis lock for cronjob_id=%s",
+                    self.pod_id,
+                    self.cronjob_id,
+                )
+                return True
+            else:
+                # Check if the current pod already holds the lock
+                current_value = await self.redis_cache.async_get_cache(self.lock_key)
+                if current_value is not None:
+                    if isinstance(current_value, bytes):
+                        current_value = current_value.decode("utf-8")
+                    if current_value == self.pod_id:
+                        verbose_proxy_logger.info(
+                            "Pod %s already holds the Redis lock for cronjob_id=%s",
+                            self.pod_id,
+                            self.cronjob_id,
+                        )
+                        return True
+            return False
         except Exception as e:
             verbose_proxy_logger.error(
-                f"Error renewing the lock for {self.cronjob_id}: {e}"
+                f"Error acquiring Redis lock for {self.cronjob_id}: {e}"
             )
+            return False

     async def release_lock(self):
         """
-        Release the lock and mark the pod as inactive.
+        Release the lock if the current pod holds it.
+        Uses get and delete commands to ensure that only the owner can release the lock.
         """
-        from litellm.proxy.proxy_server import prisma_client
-
-        if not prisma_client:
-            return False
+        if self.redis_cache is None:
+            verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
+            return
         try:
             verbose_proxy_logger.debug(
-                "Pod %s releasing lock for cronjob_id=%s", self.pod_id, self.cronjob_id
-            )
-            await prisma_client.db.litellm_cronjob.update(
-                where={"cronjob_id": self.cronjob_id, "pod_id": self.pod_id},
-                data={"status": "INACTIVE"},
-            )
-            verbose_proxy_logger.info(
-                f"Pod {self.pod_id} has released the lock for {self.cronjob_id}."
+                "Pod %s attempting to release Redis lock for cronjob_id=%s",
+                self.pod_id,
+                self.cronjob_id,
             )
+            current_value = await self.redis_cache.async_get_cache(self.lock_key)
+            if current_value is not None:
+                if isinstance(current_value, bytes):
+                    current_value = current_value.decode("utf-8")
+                if current_value == self.pod_id:
+                    result = await self.redis_cache.async_delete_cache(self.lock_key)
+                    if result == 1:
+                        verbose_proxy_logger.info(
+                            "Pod %s successfully released Redis lock for cronjob_id=%s",
+                            self.pod_id,
+                            self.cronjob_id,
+                        )
+                    else:
+                        verbose_proxy_logger.debug(
+                            "Pod %s failed to release Redis lock for cronjob_id=%s",
+                            self.pod_id,
+                            self.cronjob_id,
+                        )
+                else:
+                    verbose_proxy_logger.debug(
+                        "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
+                        self.pod_id,
+                        self.cronjob_id,
+                        current_value,
+                    )
+            else:
+                verbose_proxy_logger.debug(
+                    "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
+                    self.pod_id,
+                    self.cronjob_id,
+                )
         except Exception as e:
             verbose_proxy_logger.error(
-                f"Error releasing the lock for {self.cronjob_id}: {e}"
+                f"Error releasing Redis lock for {self.cronjob_id}: {e}"
             )

View file

@@ -349,6 +349,7 @@ class ProxyLogging:
         if redis_cache is not None:
             self.internal_usage_cache.dual_cache.redis_cache = redis_cache
             self.db_spend_update_writer.redis_update_buffer.redis_cache = redis_cache
+            self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache

     def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
         litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter)  # type: ignore

View file

@@ -2,7 +2,7 @@ import json
 import os
 import sys
 from datetime import datetime, timedelta, timezone
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
 from fastapi.testclient import TestClient
@@ -15,306 +15,251 @@ from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
 from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager


-# Mock Prisma client class
-class MockPrismaClient:
+class MockRedisCache:
     def __init__(self):
-        self.db = MagicMock()
-        self.db.litellm_cronjob = AsyncMock()
+        self.async_set_cache = AsyncMock()
+        self.async_get_cache = AsyncMock()
+        self.async_delete_cache = AsyncMock()


 @pytest.fixture
-def mock_prisma(monkeypatch):
-    mock_client = MockPrismaClient()
-
-    # Mock the prisma_client import in proxy_server
-    def mock_get_prisma():
-        return mock_client
-
-    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_client)
-    return mock_client
+def mock_redis():
+    return MockRedisCache()


 @pytest.fixture
-def pod_lock_manager():
-    return PodLockManager(cronjob_id="test_job")
+def pod_lock_manager(mock_redis):
+    return PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)


 @pytest.mark.asyncio
-async def test_acquire_lock_success(pod_lock_manager, mock_prisma):
+async def test_acquire_lock_success(pod_lock_manager, mock_redis):
     """
     Test that the lock is acquired successfully when no existing lock exists
     """
-    # Mock find_unique to return None (no existing lock)
-    mock_prisma.db.litellm_cronjob.find_unique.return_value = None
-    # Mock successful creation of new lock
-    mock_response = AsyncMock()
-    mock_response.status = "ACTIVE"
-    mock_response.pod_id = pod_lock_manager.pod_id
-    mock_prisma.db.litellm_cronjob.create.return_value = mock_response
+    # Mock successful acquisition (SET NX returns True)
+    mock_redis.async_set_cache.return_value = True

     result = await pod_lock_manager.acquire_lock()
     assert result == True

-    # Verify find_unique was called
-    mock_prisma.db.litellm_cronjob.find_unique.assert_called_once()
-    # Verify create was called with correct parameters
-    mock_prisma.db.litellm_cronjob.create.assert_called_once()
-    call_args = mock_prisma.db.litellm_cronjob.create.call_args[1]
-    assert call_args["data"]["cronjob_id"] == "test_job"
-    assert call_args["data"]["pod_id"] == pod_lock_manager.pod_id
-    assert call_args["data"]["status"] == "ACTIVE"
+    # Verify set_cache was called with correct parameters
+    mock_redis.async_set_cache.assert_called_once_with(
+        pod_lock_manager.lock_key,
+        pod_lock_manager.pod_id,
+        nx=True,
+        ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
+    )


 @pytest.mark.asyncio
-async def test_acquire_lock_existing_active(pod_lock_manager, mock_prisma):
+async def test_acquire_lock_existing_active(pod_lock_manager, mock_redis):
     """
     Test that the lock is not acquired if there's an active lock by different pod
     """
-    # Mock existing active lock
-    mock_existing = AsyncMock()
-    mock_existing.status = "ACTIVE"
-    mock_existing.pod_id = "different_pod_id"
-    mock_existing.ttl = datetime.now(timezone.utc) + timedelta(seconds=30)  # Future TTL
-    mock_prisma.db.litellm_cronjob.find_unique.return_value = mock_existing
+    # Mock failed acquisition (SET NX returns False)
+    mock_redis.async_set_cache.return_value = False
+    # Mock get_cache to return a different pod's ID
+    mock_redis.async_get_cache.return_value = "different_pod_id"

     result = await pod_lock_manager.acquire_lock()
     assert result == False

-    # Verify find_unique was called but update/create were not
-    mock_prisma.db.litellm_cronjob.find_unique.assert_called_once()
-    mock_prisma.db.litellm_cronjob.update.assert_not_called()
-    mock_prisma.db.litellm_cronjob.create.assert_not_called()
+    # Verify set_cache was called
+    mock_redis.async_set_cache.assert_called_once()
+    # Verify get_cache was called to check existing lock
+    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)


 @pytest.mark.asyncio
-async def test_acquire_lock_expired(pod_lock_manager, mock_prisma):
+async def test_acquire_lock_expired(pod_lock_manager, mock_redis):
     """
     Test that the lock can be acquired if existing lock is expired
     """
-    # Mock existing expired lock
-    mock_existing = AsyncMock()
-    mock_existing.status = "ACTIVE"
-    mock_existing.pod_id = "different_pod_id"
-    mock_existing.ttl = datetime.now(timezone.utc) - timedelta(seconds=30)  # Past TTL
-    mock_prisma.db.litellm_cronjob.find_unique.return_value = mock_existing
-
-    # Mock successful update
-    mock_updated = AsyncMock()
-    mock_updated.pod_id = pod_lock_manager.pod_id
-    mock_prisma.db.litellm_cronjob.update.return_value = mock_updated
+    # Mock failed acquisition first (SET NX returns False)
+    mock_redis.async_set_cache.return_value = False
+    # Simulate an expired lock by having the TTL return a value
+    # Since Redis auto-expires keys, an expired lock would be absent
+    # So we'll simulate a retry after the first check fails
+
+    # First check returns a value (lock exists)
+    mock_redis.async_get_cache.return_value = "different_pod_id"
+    # Then set succeeds on retry (simulating key expiring between checks)
+    mock_redis.async_set_cache.side_effect = [False, True]
+
+    result = await pod_lock_manager.acquire_lock()
+    assert result == False  # First attempt fails
+
+    # Reset mock for a second attempt
+    mock_redis.async_set_cache.reset_mock()
+    mock_redis.async_set_cache.return_value = True

+    # Try again (simulating the lock expired)
     result = await pod_lock_manager.acquire_lock()
     assert result == True

-    # Verify both find_unique and update were called
-    mock_prisma.db.litellm_cronjob.find_unique.assert_called_once()
-    mock_prisma.db.litellm_cronjob.update.assert_called_once()
+    # Verify set_cache was called again
+    mock_redis.async_set_cache.assert_called_once()


 @pytest.mark.asyncio
-async def test_renew_lock(pod_lock_manager, mock_prisma):
+async def test_release_lock_success(pod_lock_manager, mock_redis):
     """
-    Test that the renew lock calls the DB update method with the correct parameters
+    Test that the release lock works when the current pod holds the lock
     """
-    mock_prisma.db.litellm_cronjob.update.return_value = AsyncMock()
-
-    await pod_lock_manager.renew_lock()
-
-    # Verify update was called with correct parameters
-    mock_prisma.db.litellm_cronjob.update.assert_called_once()
-    call_args = mock_prisma.db.litellm_cronjob.update.call_args[1]
-    assert call_args["where"]["cronjob_id"] == "test_job"
-    assert call_args["where"]["pod_id"] == pod_lock_manager.pod_id
-    assert "ttl" in call_args["data"]
-    assert "last_updated" in call_args["data"]
-
-
-@pytest.mark.asyncio
-async def test_release_lock(pod_lock_manager, mock_prisma):
-    """
-    Test that the release lock calls the DB update method with the correct parameters
-    specifically, the status should be set to INACTIVE
-    """
-    mock_prisma.db.litellm_cronjob.update.return_value = AsyncMock()
+    # Mock get_cache to return this pod's ID
+    mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id
+    # Mock successful deletion
+    mock_redis.async_delete_cache.return_value = 1

     await pod_lock_manager.release_lock()

-    # Verify update was called with correct parameters
-    mock_prisma.db.litellm_cronjob.update.assert_called_once()
-    call_args = mock_prisma.db.litellm_cronjob.update.call_args[1]
-    assert call_args["where"]["cronjob_id"] == "test_job"
-    assert call_args["where"]["pod_id"] == pod_lock_manager.pod_id
-    assert call_args["data"]["status"] == "INACTIVE"
+    # Verify get_cache was called
+    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    # Verify delete_cache was called
+    mock_redis.async_delete_cache.assert_called_once_with(pod_lock_manager.lock_key)


 @pytest.mark.asyncio
-async def test_prisma_client_none(pod_lock_manager, monkeypatch):
-    # Mock prisma_client as None
-    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None)
-
-    # Test all methods with None client
-    assert await pod_lock_manager.acquire_lock() == False
-    assert await pod_lock_manager.renew_lock() == False
-    assert await pod_lock_manager.release_lock() == False
+async def test_release_lock_different_pod(pod_lock_manager, mock_redis):
+    """
+    Test that the release lock doesn't delete when a different pod holds the lock
+    """
+    # Mock get_cache to return a different pod's ID
+    mock_redis.async_get_cache.return_value = "different_pod_id"
+
+    await pod_lock_manager.release_lock()
+
+    # Verify get_cache was called
+    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    # Verify delete_cache was NOT called
+    mock_redis.async_delete_cache.assert_not_called()


 @pytest.mark.asyncio
-async def test_database_error_handling(pod_lock_manager, mock_prisma):
-    # Mock database errors
-    mock_prisma.db.litellm_cronjob.upsert.side_effect = Exception("Database error")
-    mock_prisma.db.litellm_cronjob.update.side_effect = Exception("Database error")
-
-    # Test error handling in all methods
-    assert await pod_lock_manager.acquire_lock() == False
-    await pod_lock_manager.renew_lock()  # Should not raise exception
-    await pod_lock_manager.release_lock()  # Should not raise exception
+async def test_release_lock_no_lock(pod_lock_manager, mock_redis):
+    """
+    Test release lock behavior when no lock exists
+    """
+    # Mock get_cache to return None (no lock)
+    mock_redis.async_get_cache.return_value = None
+
+    await pod_lock_manager.release_lock()
+
+    # Verify get_cache was called
+    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    # Verify delete_cache was NOT called
+    mock_redis.async_delete_cache.assert_not_called()


 @pytest.mark.asyncio
-async def test_acquire_lock_inactive_status(pod_lock_manager, mock_prisma):
-    """
-    Test that the lock can be acquired if existing lock is INACTIVE
-    """
-    # Mock existing inactive lock
-    mock_existing = AsyncMock()
-    mock_existing.status = "INACTIVE"
-    mock_existing.pod_id = "different_pod_id"
-    mock_existing.ttl = datetime.now(timezone.utc) + timedelta(seconds=30)
-    mock_prisma.db.litellm_cronjob.find_unique.return_value = mock_existing
-
-    # Mock successful update
-    mock_updated = AsyncMock()
-    mock_updated.pod_id = pod_lock_manager.pod_id
-    mock_prisma.db.litellm_cronjob.update.return_value = mock_updated
-
-    result = await pod_lock_manager.acquire_lock()
-    assert result == True
-
-    mock_prisma.db.litellm_cronjob.update.assert_called_once()
+async def test_redis_none(monkeypatch):
+    """
+    Test behavior when redis_cache is None
+    """
+    pod_lock_manager = PodLockManager(cronjob_id="test_job", redis_cache=None)
+
+    # Test acquire_lock with None redis_cache
+    assert await pod_lock_manager.acquire_lock() is None
+
+    # Test release_lock with None redis_cache (should not raise exception)
+    await pod_lock_manager.release_lock()


 @pytest.mark.asyncio
-async def test_acquire_lock_same_pod(pod_lock_manager, mock_prisma):
-    """
-    Test that the lock returns True if the same pod already holds the lock
-    """
-    # Mock existing active lock held by same pod
-    mock_existing = AsyncMock()
-    mock_existing.status = "ACTIVE"
-    mock_existing.pod_id = pod_lock_manager.pod_id
-    mock_existing.ttl = datetime.now(timezone.utc) + timedelta(seconds=30)
-    mock_prisma.db.litellm_cronjob.find_unique.return_value = mock_existing
-
-    result = await pod_lock_manager.acquire_lock()
-    assert result == True
-
-    # Verify no update was needed
-    mock_prisma.db.litellm_cronjob.update.assert_not_called()
-    mock_prisma.db.litellm_cronjob.create.assert_not_called()
-
-
-@pytest.mark.asyncio
-async def test_acquire_lock_race_condition(pod_lock_manager, mock_prisma):
-    """
-    Test handling of potential race conditions during lock acquisition
-    """
-    # First find_unique returns None
-    mock_prisma.db.litellm_cronjob.find_unique.return_value = None
-    # But create raises unique constraint violation
-    mock_prisma.db.litellm_cronjob.create.side_effect = Exception(
-        "Unique constraint violation"
-    )
-
+async def test_redis_error_handling(pod_lock_manager, mock_redis):
+    """
+    Test error handling in Redis operations
+    """
+    # Mock exceptions for Redis operations
+    mock_redis.async_set_cache.side_effect = Exception("Redis error")
+    mock_redis.async_get_cache.side_effect = Exception("Redis error")
+    mock_redis.async_delete_cache.side_effect = Exception("Redis error")
+
+    # Test acquire_lock error handling
     result = await pod_lock_manager.acquire_lock()
     assert result == False

-
-@pytest.mark.asyncio
-async def test_ttl_calculation(pod_lock_manager, mock_prisma):
-    """
-    Test that TTL is calculated correctly when acquiring lock
-    """
-    mock_prisma.db.litellm_cronjob.find_unique.return_value = None
-    mock_prisma.db.litellm_cronjob.create.return_value = AsyncMock()
-
-    await pod_lock_manager.acquire_lock()
-
-    call_args = mock_prisma.db.litellm_cronjob.create.call_args[1]
-    ttl = call_args["data"]["ttl"]
-
-    # Verify TTL is in the future by DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
-    expected_ttl = datetime.now(timezone.utc) + timedelta(
-        seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
-    )
-    assert abs((ttl - expected_ttl).total_seconds()) < 1  # Allow 1 second difference
+    # Reset side effect for get_cache for the release test
+    mock_redis.async_get_cache.side_effect = None
+    mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id
+
+    # Test release_lock error handling (should not raise exception)
+    await pod_lock_manager.release_lock()


 @pytest.mark.asyncio
-async def test_concurrent_lock_acquisition_simulation(mock_prisma):
+async def test_bytes_handling(pod_lock_manager, mock_redis):
+    """
+    Test handling of bytes values from Redis
+    """
+    # Mock failed acquisition
+    mock_redis.async_set_cache.return_value = False
+    # Mock get_cache to return bytes
+    mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id.encode("utf-8")
+
+    result = await pod_lock_manager.acquire_lock()
+    assert result == True
+
+    # Reset for release test
+    mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id.encode("utf-8")
+    mock_redis.async_delete_cache.return_value = 1
+
+    await pod_lock_manager.release_lock()
+    mock_redis.async_delete_cache.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_concurrent_lock_acquisition_simulation():
     """
     Simulate multiple pods trying to acquire the lock simultaneously
     """
-    pod1 = PodLockManager(cronjob_id="test_job")
-    pod2 = PodLockManager(cronjob_id="test_job")
-    pod3 = PodLockManager(cronjob_id="test_job")
+    mock_redis = MockRedisCache()
+    pod1 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+    pod2 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+    pod3 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)

     # Simulate first pod getting the lock
-    mock_prisma.db.litellm_cronjob.find_unique.return_value = None
-    mock_response = AsyncMock()
-    mock_response.pod_id = pod1.pod_id
-    mock_response.status = "ACTIVE"
-    mock_prisma.db.litellm_cronjob.create.return_value = mock_response
+    mock_redis.async_set_cache.return_value = True

     # First pod should get the lock
     result1 = await pod1.acquire_lock()
     assert result1 == True

-    # Simulate other pods trying to acquire same lock immediately after
-    mock_existing = AsyncMock()
-    mock_existing.status = "ACTIVE"
-    mock_existing.pod_id = pod1.pod_id
-    mock_existing.ttl = datetime.now(timezone.utc) + timedelta(seconds=30)
-    mock_prisma.db.litellm_cronjob.find_unique.return_value = mock_existing
+    # Simulate other pods failing to get the lock
+    mock_redis.async_set_cache.return_value = False
+    mock_redis.async_get_cache.return_value = pod1.pod_id

     # Other pods should fail to acquire
     result2 = await pod2.acquire_lock()
     result3 = await pod3.acquire_lock()
+    # Since other pods don't have the lock, they should get False
     assert result2 == False
     assert result3 == False


 @pytest.mark.asyncio
-async def test_lock_takeover_race_condition(mock_prisma):
+async def test_lock_takeover_race_condition(mock_redis):
     """
-    Test scenario where multiple pods try to take over an expired lock
+    Test scenario where multiple pods try to take over an expired lock using Redis
     """
-    pod1 = PodLockManager(cronjob_id="test_job")
-    pod2 = PodLockManager(cronjob_id="test_job")
+    pod1 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+    pod2 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)

-    # Simulate expired lock
-    mock_existing = AsyncMock()
-    mock_existing.status = "ACTIVE"
-    mock_existing.pod_id = "old_pod"
-    mock_existing.ttl = datetime.now(timezone.utc) - timedelta(seconds=30)
-    mock_prisma.db.litellm_cronjob.find_unique.return_value = mock_existing
-
-    # Simulate pod1's update succeeding
-    mock_update1 = AsyncMock()
-    mock_update1.pod_id = pod1.pod_id
-    mock_prisma.db.litellm_cronjob.update.return_value = mock_update1
-
-    # First pod should successfully take over
+    # Simulate first pod's acquisition succeeding
+    mock_redis.async_set_cache.return_value = True
+
+    # First pod should successfully acquire
     result1 = await pod1.acquire_lock()
     assert result1 == True

-    # Simulate pod2's update failing due to race condition
-    mock_prisma.db.litellm_cronjob.update.side_effect = Exception(
-        "Row was updated by another transaction"
-    )
-
-    # Second pod should fail to take over
+    # Simulate race condition: second pod tries but fails
+    mock_redis.async_set_cache.return_value = False
+    mock_redis.async_get_cache.return_value = pod1.pod_id
+
+    # Second pod should fail to acquire
     result2 = await pod2.acquire_lock()
     assert result2 == False

View file

@@ -8,6 +8,7 @@ from dotenv import load_dotenv
 from fastapi import Request
 from fastapi.routing import APIRoute
 import httpx
+import json

 load_dotenv()
 import io
@@ -72,7 +73,7 @@ verbose_proxy_logger.setLevel(level=logging.DEBUG)
 from starlette.datastructures import URL

-from litellm.caching.caching import DualCache
+from litellm.caching.caching import DualCache, RedisCache
 from litellm.proxy._types import (
     DynamoDBArgs,
     GenerateKeyRequest,
@@ -99,6 +100,12 @@ request_data = {
     ],
 }

+global_redis_cache = RedisCache(
+    host=os.getenv("REDIS_HOST"),
+    port=os.getenv("REDIS_PORT"),
+    password=os.getenv("REDIS_PASSWORD"),
+)
+

 @pytest.fixture
 def prisma_client():
@@ -131,12 +138,10 @@ async def setup_db_connection(prisma_client):

 @pytest.mark.asyncio
-async def test_pod_lock_acquisition_when_no_active_lock(prisma_client):
+async def test_pod_lock_acquisition_when_no_active_lock():
     """Test if a pod can acquire a lock when no lock is active"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)

     # Attempt to acquire lock
     result = await lock_manager.acquire_lock()
@@ -144,96 +149,84 @@ async def test_pod_lock_acquisition_when_no_active_lock():
     assert result == True, "Pod should be able to acquire lock when no lock exists"

     # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == lock_manager.pod_id
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    print("lock_record=", lock_record)
+    assert lock_record == lock_manager.pod_id


 @pytest.mark.asyncio
-async def test_pod_lock_acquisition_after_completion(prisma_client):
+async def test_pod_lock_acquisition_after_completion():
     """Test if a new pod can acquire lock after previous pod completes"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
     # First pod acquires and releases lock
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     await first_lock_manager.acquire_lock()
     await first_lock_manager.release_lock()

     # Second pod attempts to acquire lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     result = await second_lock_manager.acquire_lock()

     assert result == True, "Second pod should acquire lock after first pod releases it"

-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record == second_lock_manager.pod_id


 @pytest.mark.asyncio
-async def test_pod_lock_acquisition_after_expiry(prisma_client):
+async def test_pod_lock_acquisition_after_expiry():
     """Test if a new pod can acquire lock after previous pod's lock expires"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
     # First pod acquires lock
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     await first_lock_manager.acquire_lock()

     # release the lock from the first pod
     await first_lock_manager.release_lock()

     # Second pod attempts to acquire lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     result = await second_lock_manager.acquire_lock()

     assert (
         result == True
     ), "Second pod should acquire lock after first pod's lock expires"

-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record == second_lock_manager.pod_id


 @pytest.mark.asyncio
-async def test_pod_lock_release(prisma_client):
+async def test_pod_lock_release():
     """Test if a pod can successfully release its lock"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)

     # Acquire and then release lock
     await lock_manager.acquire_lock()
     await lock_manager.release_lock()

-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "INACTIVE"
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record is None


 @pytest.mark.asyncio
-async def test_concurrent_lock_acquisition(prisma_client):
+async def test_concurrent_lock_acquisition():
     """Test that only one pod can acquire the lock when multiple pods try simultaneously"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())

     # Create multiple lock managers simulating different pods
-    lock_manager1 = PodLockManager(cronjob_id=cronjob_id)
-    lock_manager2 = PodLockManager(cronjob_id=cronjob_id)
-    lock_manager3 = PodLockManager(cronjob_id=cronjob_id)
+    lock_manager1 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager2 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager3 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)

     # Try to acquire locks concurrently
     results = await asyncio.gather(
@@ -246,109 +239,72 @@ async def test_concurrent_lock_acquisition():
     print("all results=", results)
     assert sum(results) == 1, "Only one pod should acquire the lock"

-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id in [
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record in [
         lock_manager1.pod_id,
         lock_manager2.pod_id,
         lock_manager3.pod_id,
     ]


-@pytest.mark.asyncio
-async def test_lock_renewal(prisma_client):
-    """Test that a pod can successfully renew its lock"""
-    await setup_db_connection(prisma_client)
-
-    cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id)
-
-    # Acquire initial lock
-    await lock_manager.acquire_lock()
-
-    # Get initial TTL
-    initial_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    initial_ttl = initial_record.ttl
-
-    # Wait a short time
-    await asyncio.sleep(1)
-
-    # Renew the lock
-    await lock_manager.renew_lock()
-
-    # Get updated record
-    renewed_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-
-    assert renewed_record.ttl > initial_ttl, "Lock TTL should be extended after renewal"
-    assert renewed_record.status == "ACTIVE"
-    assert renewed_record.pod_id == lock_manager.pod_id
-
-
 @pytest.mark.asyncio
-async def test_lock_acquisition_with_expired_ttl(prisma_client):
+async def test_lock_acquisition_with_expired_ttl():
     """Test that a pod can acquire a lock when existing lock has expired TTL"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)

-    # First pod acquires lock
-    await first_lock_manager.acquire_lock()
-
-    # Manually expire the TTL
-    expired_time = datetime.now(timezone.utc) - timedelta(seconds=10)
-    await prisma_client.db.litellm_cronjob.update(
-        where={"cronjob_id": cronjob_id}, data={"ttl": expired_time}
+    # First pod acquires lock with a very short TTL to simulate expiration
+    short_ttl = 1  # 1 second
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    await global_redis_cache.async_set_cache(
+        lock_key,
+        first_lock_manager.pod_id,
+        ttl=short_ttl,
     )

+    # Wait for the lock to expire
+    await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL
+
     # Second pod tries to acquire without explicit release
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     result = await second_lock_manager.acquire_lock()

     assert result == True, "Should acquire lock when existing lock has expired TTL"

-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    # Verify in Redis
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    print("lock_record=", lock_record)
+    assert lock_record == second_lock_manager.pod_id


 @pytest.mark.asyncio
-async def test_release_expired_lock(prisma_client):
+async def test_release_expired_lock():
     """Test that a pod cannot release a lock that has been taken over by another pod"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)

-    # First pod acquires lock
-    await first_lock_manager.acquire_lock()
-
-    # Manually expire the TTL
-    expired_time = datetime.now(timezone.utc) - timedelta(seconds=10)
-    await prisma_client.db.litellm_cronjob.update(
-        where={"cronjob_id": cronjob_id}, data={"ttl": expired_time}
+    # First pod acquires lock with a very short TTL
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    short_ttl = 1  # 1 second
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    await global_redis_cache.async_set_cache(
+        lock_key,
+        first_lock_manager.pod_id,
+        ttl=short_ttl,
    )

+    # Wait for the lock to expire
+    await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL
+
     # Second pod acquires the lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     await second_lock_manager.acquire_lock()

     # First pod attempts to release its lock
     await first_lock_manager.release_lock()

     # Verify that second pod's lock is still active
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record == second_lock_manager.pod_id