test pod lock manager

Ishaan Jaff 2025-04-02 14:39:40 -07:00
parent 2e939a21b3
commit a64631edfb
2 changed files with 167 additions and 219 deletions

View file

@@ -1,137 +1,129 @@
 import uuid
-from datetime import datetime, timedelta, timezone
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
 
 from litellm._logging import verbose_proxy_logger
+from litellm.caching.redis_cache import RedisCache
 from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
 
 if TYPE_CHECKING:
-    from litellm.proxy.utils import PrismaClient, ProxyLogging
+    ProxyLogging = Any
 else:
-    PrismaClient = Any
     ProxyLogging = Any
 
 
 class PodLockManager:
     """
-    Manager for acquiring and releasing locks for cron jobs.
+    Manager for acquiring and releasing locks for cron jobs using Redis.
 
     Ensures that only one pod can run a cron job at a time.
     """
 
-    def __init__(self, cronjob_id: str):
+    def __init__(self, cronjob_id: str, redis_cache: Optional[RedisCache] = None):
         self.pod_id = str(uuid.uuid4())
         self.cronjob_id = cronjob_id
+        self.redis_cache = redis_cache
+        # Define a unique key for this cronjob lock in Redis.
+        self.lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+
+    @staticmethod
+    def get_redis_lock_key(cronjob_id: str) -> str:
+        return f"cronjob_lock:{cronjob_id}"
 
-    async def acquire_lock(self) -> bool:
+    async def acquire_lock(self) -> Optional[bool]:
         """
-        Attempt to acquire the lock for a specific cron job using database locking.
+        Attempt to acquire the lock for a specific cron job using Redis.
+        Uses the SET command with NX and EX options to ensure atomicity.
         """
-        from litellm.proxy.proxy_server import prisma_client
-
-        verbose_proxy_logger.debug(
-            "Pod %s acquiring lock for cronjob_id=%s", self.pod_id, self.cronjob_id
-        )
-        if not prisma_client:
-            verbose_proxy_logger.debug("prisma is None, returning False")
-            return False
-        try:
-            current_time = datetime.now(timezone.utc)
-            ttl_expiry = current_time + timedelta(
-                seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
-            )
-
-            # Use Prisma's findUnique with FOR UPDATE lock to prevent race conditions
-            lock_record = await prisma_client.db.litellm_cronjob.find_unique(
-                where={"cronjob_id": self.cronjob_id},
-            )
-            if lock_record:
-                # If record exists, only update if it's inactive or expired
-                if lock_record.status == "ACTIVE" and lock_record.ttl > current_time:
-                    return lock_record.pod_id == self.pod_id
-
-                # Update existing record
-                updated_lock = await prisma_client.db.litellm_cronjob.update(
-                    where={"cronjob_id": self.cronjob_id},
-                    data={
-                        "pod_id": self.pod_id,
-                        "status": "ACTIVE",
-                        "last_updated": current_time,
-                        "ttl": ttl_expiry,
-                    },
-                )
-            else:
-                # Create new record if none exists
-                updated_lock = await prisma_client.db.litellm_cronjob.create(
-                    data={
-                        "cronjob_id": self.cronjob_id,
-                        "pod_id": self.pod_id,
-                        "status": "ACTIVE",
-                        "last_updated": current_time,
-                        "ttl": ttl_expiry,
-                    }
-                )
-            return updated_lock.pod_id == self.pod_id
-        except Exception as e:
-            verbose_proxy_logger.error(
-                f"Error acquiring the lock for {self.cronjob_id}: {e}"
-            )
-            return False
-
-    async def renew_lock(self):
-        """
-        Renew the lock (update the TTL) for the pod holding the lock.
-        """
-        from litellm.proxy.proxy_server import prisma_client
-
-        if not prisma_client:
-            return False
-        try:
-            verbose_proxy_logger.debug(
-                "renewing lock for cronjob_id=%s", self.cronjob_id
-            )
-            current_time = datetime.now(timezone.utc)
-            # Extend the TTL for another DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
-            ttl_expiry = current_time + timedelta(
-                seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
-            )
-
-            await prisma_client.db.litellm_cronjob.update(
-                where={"cronjob_id": self.cronjob_id, "pod_id": self.pod_id},
-                data={"ttl": ttl_expiry, "last_updated": current_time},
-            )
-
-            verbose_proxy_logger.info(
-                f"Renewed the lock for Pod {self.pod_id} for {self.cronjob_id}"
-            )
-        except Exception as e:
-            verbose_proxy_logger.error(
-                f"Error renewing the lock for {self.cronjob_id}: {e}"
-            )
+        if self.redis_cache is None:
+            verbose_proxy_logger.debug("redis_cache is None, skipping acquire_lock")
+            return None
+        try:
+            verbose_proxy_logger.debug(
+                "Pod %s attempting to acquire Redis lock for cronjob_id=%s",
+                self.pod_id,
+                self.cronjob_id,
+            )
+            # Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX)
+            # and with an expiration (EX) to avoid deadlocks.
+            acquired = await self.redis_cache.async_set_cache(
+                self.lock_key,
+                self.pod_id,
+                nx=True,
+                ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
+            )
+            if acquired:
+                verbose_proxy_logger.info(
+                    "Pod %s successfully acquired Redis lock for cronjob_id=%s",
+                    self.pod_id,
+                    self.cronjob_id,
+                )
+                return True
+            else:
+                # Check if the current pod already holds the lock
+                current_value = await self.redis_cache.async_get_cache(self.lock_key)
+                if current_value is not None:
+                    if isinstance(current_value, bytes):
+                        current_value = current_value.decode("utf-8")
+                    if current_value == self.pod_id:
+                        verbose_proxy_logger.info(
+                            "Pod %s already holds the Redis lock for cronjob_id=%s",
+                            self.pod_id,
+                            self.cronjob_id,
+                        )
+                        return True
+            return False
+        except Exception as e:
+            verbose_proxy_logger.error(
+                f"Error acquiring Redis lock for {self.cronjob_id}: {e}"
+            )
+            return False
 
     async def release_lock(self):
         """
-        Release the lock and mark the pod as inactive.
+        Release the lock if the current pod holds it.
+        Uses get and delete commands to ensure that only the owner can release the lock.
         """
-        from litellm.proxy.proxy_server import prisma_client
-
-        if not prisma_client:
-            return False
+        if self.redis_cache is None:
+            verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
+            return
         try:
             verbose_proxy_logger.debug(
-                "Pod %s releasing lock for cronjob_id=%s", self.pod_id, self.cronjob_id
-            )
-            await prisma_client.db.litellm_cronjob.update(
-                where={"cronjob_id": self.cronjob_id, "pod_id": self.pod_id},
-                data={"status": "INACTIVE"},
+                "Pod %s attempting to release Redis lock for cronjob_id=%s",
+                self.pod_id,
+                self.cronjob_id,
             )
-            verbose_proxy_logger.info(
-                f"Pod {self.pod_id} has released the lock for {self.cronjob_id}."
-            )
+            current_value = await self.redis_cache.async_get_cache(self.lock_key)
+            if current_value is not None:
+                if isinstance(current_value, bytes):
+                    current_value = current_value.decode("utf-8")
+                if current_value == self.pod_id:
+                    result = await self.redis_cache.async_delete_cache(self.lock_key)
+                    if result == 1:
+                        verbose_proxy_logger.info(
+                            "Pod %s successfully released Redis lock for cronjob_id=%s",
+                            self.pod_id,
+                            self.cronjob_id,
+                        )
+                    else:
+                        verbose_proxy_logger.debug(
+                            "Pod %s failed to release Redis lock for cronjob_id=%s",
+                            self.pod_id,
+                            self.cronjob_id,
+                        )
+                else:
+                    verbose_proxy_logger.debug(
+                        "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
+                        self.pod_id,
+                        self.cronjob_id,
+                        current_value,
+                    )
+            else:
+                verbose_proxy_logger.debug(
+                    "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
+                    self.pod_id,
+                    self.cronjob_id,
+                )
         except Exception as e:
             verbose_proxy_logger.error(
-                f"Error releasing the lock for {self.cronjob_id}: {e}"
+                f"Error releasing Redis lock for {self.cronjob_id}: {e}"
             )

View file

@@ -8,6 +8,7 @@ from dotenv import load_dotenv
 from fastapi import Request
 from fastapi.routing import APIRoute
 import httpx
+import json
 
 load_dotenv()
 import io
@@ -72,7 +73,7 @@ verbose_proxy_logger.setLevel(level=logging.DEBUG)
 from starlette.datastructures import URL
 
-from litellm.caching.caching import DualCache
+from litellm.caching.caching import DualCache, RedisCache
 from litellm.proxy._types import (
     DynamoDBArgs,
     GenerateKeyRequest,
@@ -99,6 +100,12 @@ request_data = {
     ],
 }
 
+global_redis_cache = RedisCache(
+    host=os.getenv("REDIS_HOST"),
+    port=os.getenv("REDIS_PORT"),
+    password=os.getenv("REDIS_PASSWORD"),
+)
+
 
 @pytest.fixture
 def prisma_client():
@@ -131,12 +138,10 @@ async def setup_db_connection(prisma_client):
 
 @pytest.mark.asyncio
-async def test_pod_lock_acquisition_when_no_active_lock(prisma_client):
+async def test_pod_lock_acquisition_when_no_active_lock():
     """Test if a pod can acquire a lock when no lock is active"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
 
     # Attempt to acquire lock
     result = await lock_manager.acquire_lock()
@@ -144,96 +149,84 @@ async def test_pod_lock_acquisition_when_no_active_lock(prisma_client):
     assert result == True, "Pod should be able to acquire lock when no lock exists"
 
     # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == lock_manager.pod_id
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    print("lock_record=", lock_record)
+    assert lock_record == lock_manager.pod_id
 
 
 @pytest.mark.asyncio
-async def test_pod_lock_acquisition_after_completion(prisma_client):
+async def test_pod_lock_acquisition_after_completion():
     """Test if a new pod can acquire lock after previous pod completes"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
     # First pod acquires and releases lock
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     await first_lock_manager.acquire_lock()
     await first_lock_manager.release_lock()
 
     # Second pod attempts to acquire lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     result = await second_lock_manager.acquire_lock()
 
     assert result == True, "Second pod should acquire lock after first pod releases it"
 
-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record == second_lock_manager.pod_id
 
 
 @pytest.mark.asyncio
-async def test_pod_lock_acquisition_after_expiry(prisma_client):
+async def test_pod_lock_acquisition_after_expiry():
     """Test if a new pod can acquire lock after previous pod's lock expires"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
     # First pod acquires lock
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
    await first_lock_manager.acquire_lock()
 
     # release the lock from the first pod
     await first_lock_manager.release_lock()
 
     # Second pod attempts to acquire lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     result = await second_lock_manager.acquire_lock()
 
     assert (
         result == True
     ), "Second pod should acquire lock after first pod's lock expires"
 
-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record == second_lock_manager.pod_id
 
 
 @pytest.mark.asyncio
-async def test_pod_lock_release(prisma_client):
+async def test_pod_lock_release():
     """Test if a pod can successfully release its lock"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
 
     # Acquire and then release lock
     await lock_manager.acquire_lock()
     await lock_manager.release_lock()
 
-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "INACTIVE"
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record is None
 
 
 @pytest.mark.asyncio
-async def test_concurrent_lock_acquisition(prisma_client):
+async def test_concurrent_lock_acquisition():
     """Test that only one pod can acquire the lock when multiple pods try simultaneously"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
 
     # Create multiple lock managers simulating different pods
-    lock_manager1 = PodLockManager(cronjob_id=cronjob_id)
-    lock_manager2 = PodLockManager(cronjob_id=cronjob_id)
-    lock_manager3 = PodLockManager(cronjob_id=cronjob_id)
+    lock_manager1 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager2 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager3 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
 
     # Try to acquire locks concurrently
     results = await asyncio.gather(
@@ -246,109 +239,72 @@ async def test_concurrent_lock_acquisition(prisma_client):
     print("all results=", results)
     assert sum(results) == 1, "Only one pod should acquire the lock"
 
-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id in [
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record in [
         lock_manager1.pod_id,
         lock_manager2.pod_id,
         lock_manager3.pod_id,
     ]
 
 
-@pytest.mark.asyncio
-async def test_lock_renewal(prisma_client):
-    """Test that a pod can successfully renew its lock"""
-    await setup_db_connection(prisma_client)
-
-    cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id)
-
-    # Acquire initial lock
-    await lock_manager.acquire_lock()
-
-    # Get initial TTL
-    initial_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    initial_ttl = initial_record.ttl
-
-    # Wait a short time
-    await asyncio.sleep(1)
-
-    # Renew the lock
-    await lock_manager.renew_lock()
-
-    # Get updated record
-    renewed_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-
-    assert renewed_record.ttl > initial_ttl, "Lock TTL should be extended after renewal"
-    assert renewed_record.status == "ACTIVE"
-    assert renewed_record.pod_id == lock_manager.pod_id
-
-
 @pytest.mark.asyncio
-async def test_lock_acquisition_with_expired_ttl(prisma_client):
+async def test_lock_acquisition_with_expired_ttl():
     """Test that a pod can acquire a lock when existing lock has expired TTL"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
 
-    # First pod acquires lock
-    await first_lock_manager.acquire_lock()
-
-    # Manually expire the TTL
-    expired_time = datetime.now(timezone.utc) - timedelta(seconds=10)
-    await prisma_client.db.litellm_cronjob.update(
-        where={"cronjob_id": cronjob_id}, data={"ttl": expired_time}
-    )
+    # First pod acquires lock with a very short TTL to simulate expiration
+    short_ttl = 1  # 1 second
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    await global_redis_cache.async_set_cache(
+        lock_key,
+        first_lock_manager.pod_id,
+        ttl=short_ttl,
+    )
+
+    # Wait for the lock to expire
+    await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL
 
     # Second pod tries to acquire without explicit release
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     result = await second_lock_manager.acquire_lock()
 
     assert result == True, "Should acquire lock when existing lock has expired TTL"
 
-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    # Verify in Redis
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    print("lock_record=", lock_record)
+    assert lock_record == second_lock_manager.pod_id
 
 
 @pytest.mark.asyncio
-async def test_release_expired_lock(prisma_client):
+async def test_release_expired_lock():
     """Test that a pod cannot release a lock that has been taken over by another pod"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)
 
-    # First pod acquires lock
-    await first_lock_manager.acquire_lock()
-
-    # Manually expire the TTL
-    expired_time = datetime.now(timezone.utc) - timedelta(seconds=10)
-    await prisma_client.db.litellm_cronjob.update(
-        where={"cronjob_id": cronjob_id}, data={"ttl": expired_time}
-    )
+    # First pod acquires lock with a very short TTL
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    short_ttl = 1  # 1 second
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    await global_redis_cache.async_set_cache(
+        lock_key,
+        first_lock_manager.pod_id,
+        ttl=short_ttl,
+    )
+
+    # Wait for the lock to expire
+    await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL
 
     # Second pod acquires the lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     await second_lock_manager.acquire_lock()
 
     # First pod attempts to release its lock
     await first_lock_manager.release_lock()
 
     # Verify that second pod's lock is still active
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record == second_lock_manager.pod_id
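Note: release_lock above performs the owner check its docstring describes as a GET followed by a DELETE, which is what test_release_expired_lock exercises. In raw Redis, that compare-and-delete is often expressed atomically with a short Lua script instead; a generic sketch of that variant (plain redis-py, not litellm code; key and pod id are illustrative):

    import asyncio

    import redis.asyncio as redis

    # Delete the lock only if its value still matches this pod's id, in one
    # atomic server-side step (no window between the GET and the DEL).
    RELEASE_SCRIPT = """
    if redis.call('get', KEYS[1]) == ARGV[1] then
        return redis.call('del', KEYS[1])
    else
        return 0
    end
    """


    async def main():
        r = redis.Redis(decode_responses=True)
        await r.set("cronjob_lock:demo", "pod-a", nx=True, ex=60)
        released = await r.eval(RELEASE_SCRIPT, 1, "cronjob_lock:demo", "pod-a")
        print("released:", released == 1)  # True only if pod-a still owned the lock
        await r.aclose()  # redis-py >= 5; use close() on older versions


    asyncio.run(main())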