test pod lock manager

This commit is contained in:
Ishaan Jaff 2025-04-02 14:39:40 -07:00
parent 2e939a21b3
commit a64631edfb
2 changed files with 167 additions and 219 deletions

View file

@@ -1,137 +1,129 @@
 import uuid
-from datetime import datetime, timedelta, timezone
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
 
 from litellm._logging import verbose_proxy_logger
+from litellm.caching.redis_cache import RedisCache
 from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
 
 if TYPE_CHECKING:
-    from litellm.proxy.utils import PrismaClient, ProxyLogging
+    ProxyLogging = Any
 else:
-    PrismaClient = Any
     ProxyLogging = Any
 
 
 class PodLockManager:
     """
-    Manager for acquiring and releasing locks for cron jobs.
+    Manager for acquiring and releasing locks for cron jobs using Redis.
 
     Ensures that only one pod can run a cron job at a time.
     """
 
-    def __init__(self, cronjob_id: str):
+    def __init__(self, cronjob_id: str, redis_cache: Optional[RedisCache] = None):
         self.pod_id = str(uuid.uuid4())
         self.cronjob_id = cronjob_id
+        self.redis_cache = redis_cache
+        # Define a unique key for this cronjob lock in Redis.
+        self.lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
 
-    async def acquire_lock(self) -> bool:
+    @staticmethod
+    def get_redis_lock_key(cronjob_id: str) -> str:
+        return f"cronjob_lock:{cronjob_id}"
+
+    async def acquire_lock(self) -> Optional[bool]:
         """
-        Attempt to acquire the lock for a specific cron job using database locking.
+        Attempt to acquire the lock for a specific cron job using Redis.
+        Uses the SET command with NX and EX options to ensure atomicity.
         """
-        from litellm.proxy.proxy_server import prisma_client
-
-        verbose_proxy_logger.debug(
-            "Pod %s acquiring lock for cronjob_id=%s", self.pod_id, self.cronjob_id
-        )
-        if not prisma_client:
-            verbose_proxy_logger.debug("prisma is None, returning False")
-            return False
-        try:
-            current_time = datetime.now(timezone.utc)
-            ttl_expiry = current_time + timedelta(
-                seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
-            )
-            # Use Prisma's findUnique with FOR UPDATE lock to prevent race conditions
-            lock_record = await prisma_client.db.litellm_cronjob.find_unique(
-                where={"cronjob_id": self.cronjob_id},
-            )
-            if lock_record:
-                # If record exists, only update if it's inactive or expired
-                if lock_record.status == "ACTIVE" and lock_record.ttl > current_time:
-                    return lock_record.pod_id == self.pod_id
-                # Update existing record
-                updated_lock = await prisma_client.db.litellm_cronjob.update(
-                    where={"cronjob_id": self.cronjob_id},
-                    data={
-                        "pod_id": self.pod_id,
-                        "status": "ACTIVE",
-                        "last_updated": current_time,
-                        "ttl": ttl_expiry,
-                    },
-                )
-            else:
-                # Create new record if none exists
-                updated_lock = await prisma_client.db.litellm_cronjob.create(
-                    data={
-                        "cronjob_id": self.cronjob_id,
-                        "pod_id": self.pod_id,
-                        "status": "ACTIVE",
-                        "last_updated": current_time,
-                        "ttl": ttl_expiry,
-                    }
-                )
-            return updated_lock.pod_id == self.pod_id
-        except Exception as e:
-            verbose_proxy_logger.error(
-                f"Error acquiring the lock for {self.cronjob_id}: {e}"
-            )
-            return False
-
-    async def renew_lock(self):
-        """
-        Renew the lock (update the TTL) for the pod holding the lock.
-        """
-        from litellm.proxy.proxy_server import prisma_client
-
-        if not prisma_client:
-            return False
+        if self.redis_cache is None:
+            verbose_proxy_logger.debug("redis_cache is None, skipping acquire_lock")
+            return None
         try:
             verbose_proxy_logger.debug(
-                "renewing lock for cronjob_id=%s", self.cronjob_id
-            )
-            current_time = datetime.now(timezone.utc)
-            # Extend the TTL for another DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
-            ttl_expiry = current_time + timedelta(
-                seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
-            )
-            await prisma_client.db.litellm_cronjob.update(
-                where={"cronjob_id": self.cronjob_id, "pod_id": self.pod_id},
-                data={"ttl": ttl_expiry, "last_updated": current_time},
-            )
-            verbose_proxy_logger.info(
-                f"Renewed the lock for Pod {self.pod_id} for {self.cronjob_id}"
+                "Pod %s attempting to acquire Redis lock for cronjob_id=%s",
+                self.pod_id,
+                self.cronjob_id,
+            )
+            # Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX)
+            # and with an expiration (EX) to avoid deadlocks.
+            acquired = await self.redis_cache.async_set_cache(
+                self.lock_key,
+                self.pod_id,
+                nx=True,
+                ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
             )
+            if acquired:
+                verbose_proxy_logger.info(
+                    "Pod %s successfully acquired Redis lock for cronjob_id=%s",
+                    self.pod_id,
+                    self.cronjob_id,
+                )
+                return True
+            else:
+                # Check if the current pod already holds the lock
+                current_value = await self.redis_cache.async_get_cache(self.lock_key)
+                if current_value is not None:
+                    if isinstance(current_value, bytes):
+                        current_value = current_value.decode("utf-8")
+                    if current_value == self.pod_id:
+                        verbose_proxy_logger.info(
+                            "Pod %s already holds the Redis lock for cronjob_id=%s",
+                            self.pod_id,
+                            self.cronjob_id,
+                        )
+                        return True
+            return False
         except Exception as e:
             verbose_proxy_logger.error(
-                f"Error renewing the lock for {self.cronjob_id}: {e}"
+                f"Error acquiring Redis lock for {self.cronjob_id}: {e}"
             )
+            return False
 
     async def release_lock(self):
         """
-        Release the lock and mark the pod as inactive.
+        Release the lock if the current pod holds it.
+        Uses get and delete commands to ensure that only the owner can release the lock.
         """
-        from litellm.proxy.proxy_server import prisma_client
-
-        if not prisma_client:
-            return False
+        if self.redis_cache is None:
+            verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
+            return
         try:
             verbose_proxy_logger.debug(
-                "Pod %s releasing lock for cronjob_id=%s", self.pod_id, self.cronjob_id
-            )
-            await prisma_client.db.litellm_cronjob.update(
-                where={"cronjob_id": self.cronjob_id, "pod_id": self.pod_id},
-                data={"status": "INACTIVE"},
-            )
-            verbose_proxy_logger.info(
-                f"Pod {self.pod_id} has released the lock for {self.cronjob_id}."
+                "Pod %s attempting to release Redis lock for cronjob_id=%s",
+                self.pod_id,
+                self.cronjob_id,
             )
+            current_value = await self.redis_cache.async_get_cache(self.lock_key)
+            if current_value is not None:
+                if isinstance(current_value, bytes):
+                    current_value = current_value.decode("utf-8")
+                if current_value == self.pod_id:
+                    result = await self.redis_cache.async_delete_cache(self.lock_key)
+                    if result == 1:
+                        verbose_proxy_logger.info(
+                            "Pod %s successfully released Redis lock for cronjob_id=%s",
+                            self.pod_id,
+                            self.cronjob_id,
+                        )
+                    else:
+                        verbose_proxy_logger.debug(
+                            "Pod %s failed to release Redis lock for cronjob_id=%s",
+                            self.pod_id,
+                            self.cronjob_id,
+                        )
+                else:
+                    verbose_proxy_logger.debug(
+                        "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
+                        self.pod_id,
+                        self.cronjob_id,
+                        current_value,
+                    )
+            else:
+                verbose_proxy_logger.debug(
+                    "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
+                    self.pod_id,
+                    self.cronjob_id,
+                )
         except Exception as e:
             verbose_proxy_logger.error(
-                f"Error releasing the lock for {self.cronjob_id}: {e}"
+                f"Error releasing Redis lock for {self.cronjob_id}: {e}"
             )
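
A note on intended usage: with this change, a pod-safe cron job wraps its work in `acquire_lock` / `release_lock` on a shared `PodLockManager`. The sketch below is illustrative only; the import path, Redis connection settings, and job name are assumptions, not part of this commit:

```python
import asyncio

from litellm.caching.redis_cache import RedisCache

# Import path assumed for illustration; adjust to the actual repo layout.
from litellm.proxy.db.pod_lock_manager import PodLockManager


async def run_job_if_leader():
    # Hypothetical Redis connection settings.
    redis_cache = RedisCache(host="localhost", port=6379)
    lock_manager = PodLockManager(cronjob_id="nightly_cleanup", redis_cache=redis_cache)

    # At most one pod in the deployment sees True here; the TTL attached at
    # acquisition means the lock frees itself if this pod dies mid-job.
    if await lock_manager.acquire_lock():
        try:
            ...  # the cron job's actual work goes here
        finally:
            # release_lock only deletes the key if this pod's pod_id is still
            # the key's value, so pods that lost the race are unaffected.
            await lock_manager.release_lock()


asyncio.run(run_job_if_leader())
```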

View file

@@ -8,6 +8,7 @@ from dotenv import load_dotenv
 from fastapi import Request
 from fastapi.routing import APIRoute
 import httpx
+import json
 
 load_dotenv()
 import io
@@ -72,7 +73,7 @@ verbose_proxy_logger.setLevel(level=logging.DEBUG)
 from starlette.datastructures import URL
 
-from litellm.caching.caching import DualCache
+from litellm.caching.caching import DualCache, RedisCache
 from litellm.proxy._types import (
     DynamoDBArgs,
     GenerateKeyRequest,
@@ -99,6 +100,12 @@ request_data = {
     ],
 }
 
+global_redis_cache = RedisCache(
+    host=os.getenv("REDIS_HOST"),
+    port=os.getenv("REDIS_PORT"),
+    password=os.getenv("REDIS_PASSWORD"),
+)
+
 
 @pytest.fixture
 def prisma_client():
@@ -131,12 +138,10 @@ async def setup_db_connection(prisma_client):
 
 
 @pytest.mark.asyncio
-async def test_pod_lock_acquisition_when_no_active_lock(prisma_client):
+async def test_pod_lock_acquisition_when_no_active_lock():
     """Test if a pod can acquire a lock when no lock is active"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
 
     # Attempt to acquire lock
     result = await lock_manager.acquire_lock()
@@ -144,96 +149,84 @@ async def test_pod_lock_acquisition_when_no_active_lock(prisma_client):
     assert result == True, "Pod should be able to acquire lock when no lock exists"
 
-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == lock_manager.pod_id
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    print("lock_record=", lock_record)
+    assert lock_record == lock_manager.pod_id
 
 
 @pytest.mark.asyncio
-async def test_pod_lock_acquisition_after_completion(prisma_client):
+async def test_pod_lock_acquisition_after_completion():
     """Test if a new pod can acquire lock after previous pod completes"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
     # First pod acquires and releases lock
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     await first_lock_manager.acquire_lock()
     await first_lock_manager.release_lock()
 
     # Second pod attempts to acquire lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     result = await second_lock_manager.acquire_lock()
 
     assert result == True, "Second pod should acquire lock after first pod releases it"
 
-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record == second_lock_manager.pod_id
 
 
 @pytest.mark.asyncio
-async def test_pod_lock_acquisition_after_expiry(prisma_client):
+async def test_pod_lock_acquisition_after_expiry():
     """Test if a new pod can acquire lock after previous pod's lock expires"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
     # First pod acquires lock
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     await first_lock_manager.acquire_lock()
 
     # release the lock from the first pod
     await first_lock_manager.release_lock()
 
     # Second pod attempts to acquire lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     result = await second_lock_manager.acquire_lock()
 
     assert (
         result == True
     ), "Second pod should acquire lock after first pod's lock expires"
 
-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record == second_lock_manager.pod_id
 
 
 @pytest.mark.asyncio
-async def test_pod_lock_release(prisma_client):
+async def test_pod_lock_release():
     """Test if a pod can successfully release its lock"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
 
     # Acquire and then release lock
     await lock_manager.acquire_lock()
     await lock_manager.release_lock()
 
-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "INACTIVE"
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record is None
 
 
 @pytest.mark.asyncio
-async def test_concurrent_lock_acquisition(prisma_client):
+async def test_concurrent_lock_acquisition():
     """Test that only one pod can acquire the lock when multiple pods try simultaneously"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
 
     # Create multiple lock managers simulating different pods
-    lock_manager1 = PodLockManager(cronjob_id=cronjob_id)
-    lock_manager2 = PodLockManager(cronjob_id=cronjob_id)
-    lock_manager3 = PodLockManager(cronjob_id=cronjob_id)
+    lock_manager1 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager2 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager3 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
 
     # Try to acquire locks concurrently
     results = await asyncio.gather(
@@ -246,109 +239,72 @@ async def test_concurrent_lock_acquisition(prisma_client):
     print("all results=", results)
     assert sum(results) == 1, "Only one pod should acquire the lock"
 
-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id in [
+    # Verify in redis
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record in [
         lock_manager1.pod_id,
         lock_manager2.pod_id,
         lock_manager3.pod_id,
     ]
 
 
 @pytest.mark.asyncio
-async def test_lock_renewal(prisma_client):
-    """Test that a pod can successfully renew its lock"""
-    await setup_db_connection(prisma_client)
-
-    cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id)
-
-    # Acquire initial lock
-    await lock_manager.acquire_lock()
-
-    # Get initial TTL
-    initial_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    initial_ttl = initial_record.ttl
-
-    # Wait a short time
-    await asyncio.sleep(1)
-
-    # Renew the lock
-    await lock_manager.renew_lock()
-
-    # Get updated record
-    renewed_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-
-    assert renewed_record.ttl > initial_ttl, "Lock TTL should be extended after renewal"
-    assert renewed_record.status == "ACTIVE"
-    assert renewed_record.pod_id == lock_manager.pod_id
-
-
-@pytest.mark.asyncio
-async def test_lock_acquisition_with_expired_ttl(prisma_client):
+async def test_lock_acquisition_with_expired_ttl():
     """Test that a pod can acquire a lock when existing lock has expired TTL"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
 
-    # First pod acquires lock
-    await first_lock_manager.acquire_lock()
-
-    # Manually expire the TTL
-    expired_time = datetime.now(timezone.utc) - timedelta(seconds=10)
-    await prisma_client.db.litellm_cronjob.update(
-        where={"cronjob_id": cronjob_id}, data={"ttl": expired_time}
+    # First pod acquires lock with a very short TTL to simulate expiration
+    short_ttl = 1  # 1 second
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    await global_redis_cache.async_set_cache(
+        lock_key,
+        first_lock_manager.pod_id,
+        ttl=short_ttl,
     )
 
+    # Wait for the lock to expire
+    await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL
+
     # Second pod tries to acquire without explicit release
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     result = await second_lock_manager.acquire_lock()
 
     assert result == True, "Should acquire lock when existing lock has expired TTL"
 
-    # Verify in database
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    # Verify in Redis
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    print("lock_record=", lock_record)
+    assert lock_record == second_lock_manager.pod_id
 
 
 @pytest.mark.asyncio
-async def test_release_expired_lock(prisma_client):
+async def test_release_expired_lock():
     """Test that a pod cannot release a lock that has been taken over by another pod"""
-    await setup_db_connection(prisma_client)
-
     cronjob_id = str(uuid.uuid4())
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id)
 
-    # First pod acquires lock
-    await first_lock_manager.acquire_lock()
-
-    # Manually expire the TTL
-    expired_time = datetime.now(timezone.utc) - timedelta(seconds=10)
-    await prisma_client.db.litellm_cronjob.update(
-        where={"cronjob_id": cronjob_id}, data={"ttl": expired_time}
+    # First pod acquires lock with a very short TTL
+    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    short_ttl = 1  # 1 second
+    lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+    await global_redis_cache.async_set_cache(
+        lock_key,
+        first_lock_manager.pod_id,
+        ttl=short_ttl,
     )
 
+    # Wait for the lock to expire
+    await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL
+
     # Second pod acquires the lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id)
+    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
     await second_lock_manager.acquire_lock()
 
     # First pod attempts to release its lock
     await first_lock_manager.release_lock()
 
     # Verify that second pod's lock is still active
-    lock_record = await prisma_client.db.litellm_cronjob.find_first(
-        where={"cronjob_id": cronjob_id}
-    )
-    assert lock_record.status == "ACTIVE"
-    assert lock_record.pod_id == second_lock_manager.pod_id
+    lock_record = await global_redis_cache.async_get_cache(lock_key)
+    assert lock_record == second_lock_manager.pod_id
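
For reference, the atomicity claim in the `acquire_lock` docstring comes from Redis itself: `SET key value NX EX ttl` checks for existence, writes the value, and attaches the expiry in a single server-side command, so two racing pods can never both succeed. Below is a minimal sketch of that primitive with plain `redis-py` (this illustrates the underlying command, not LiteLLM's `RedisCache` wrapper; the connection settings are assumptions). It also shows the usual Lua-script hardening for release, since the GET-then-DELETE pattern in `release_lock` above leaves a small window where the key can expire and be re-acquired by another pod between the two calls; that hardening is not part of this commit:

```python
import asyncio
import uuid

import redis.asyncio as redis  # pip install redis


async def main():
    r = redis.Redis(host="localhost", port=6379, decode_responses=True)
    pod_id = str(uuid.uuid4())
    lock_key = "cronjob_lock:demo"

    # Atomic acquire: write only if the key is absent (NX), with a TTL (EX).
    won = await r.set(lock_key, pod_id, nx=True, ex=60)
    print("acquired:", bool(won))

    # Atomic owner-checked release: compare and delete in one Lua script,
    # executed server-side, so the key cannot change hands between the
    # check and the delete.
    release_script = """
    if redis.call("get", KEYS[1]) == ARGV[1] then
        return redis.call("del", KEYS[1])
    else
        return 0
    end
    """
    released = await r.eval(release_script, 1, lock_key, pod_id)
    print("released:", released == 1)


asyncio.run(main())
```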