mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 03:34:10 +00:00)
rename pod lock manager
parent 0edd4aa8a7
commit c53d172b06

3 changed files with 131 additions and 4 deletions
@@ -22,7 +22,7 @@ from litellm.proxy._types import (
     LiteLLM_UserTable,
     SpendLogsPayload,
 )
-from litellm.proxy.db.pod_leader_manager import PodLockManager
+from litellm.proxy.db.pod_lock_manager import PodLockManager
 from litellm.proxy.db.redis_update_buffer import (
     DBSpendUpdateTransactions,
     RedisUpdateBuffer,
@@ -49,7 +49,7 @@ class DBSpendUpdateWriter:
     ):
         self.redis_cache = redis_cache
         self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
-        self.pod_leader_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
+        self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
 
     @staticmethod
     async def update_database(
@@ -405,7 +405,7 @@ class DBSpendUpdateWriter:
             )
 
             # Only commit from redis to db if this pod is the leader
-            if await self.pod_leader_manager.acquire_lock():
+            if await self.pod_lock_manager.acquire_lock():
                 verbose_proxy_logger.debug("acquired lock for spend updates")
 
                 try:
@@ -422,7 +422,7 @@ class DBSpendUpdateWriter:
                 except Exception as e:
                     verbose_proxy_logger.error(f"Error committing spend updates: {e}")
                 finally:
-                    await self.pod_leader_manager.release_lock()
+                    await self.pod_lock_manager.release_lock()
             else:
                 db_spend_update_transactions = DBSpendUpdateTransactions(
                     user_list_transactions=prisma_client.user_list_transactions,
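For context on the rename: the attribute now matches the module it is imported from, litellm/proxy/db/pod_lock_manager.py. That module is not part of this diff, so the following is only a rough sketch, assuming the interface that the hunks above and the new tests below exercise (a cronjob_id/pod_id pair plus async acquire_lock, renew_lock, and release_lock methods backed by the litellm_cronjob Prisma table); the method bodies and the exact upsert/update payloads are inferred, not copied from the repository.

# Illustrative sketch only, not the implementation shipped in this commit.
import uuid
from datetime import datetime, timedelta, timezone

from litellm._logging import verbose_proxy_logger
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS


class PodLockManager:
    """Best-effort leader election: at most one pod holds the lock for a given cronjob_id."""

    def __init__(self, cronjob_id: str):
        self.cronjob_id = cronjob_id
        self.pod_id = str(uuid.uuid4())  # unique per running pod/process

    async def acquire_lock(self) -> bool:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            return False
        try:
            expiry = datetime.now(timezone.utc) + timedelta(
                seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
            )
            # Upsert the row for this cronjob_id; the lock is ours only if the
            # returned row is ACTIVE and carries our pod_id.
            record = await prisma_client.db.litellm_cronjob.upsert(
                where={"cronjob_id": self.cronjob_id},
                data={
                    "create": {
                        "cronjob_id": self.cronjob_id,
                        "pod_id": self.pod_id,
                        "status": "ACTIVE",
                        "ttl": expiry,
                    },
                    # The real update clause decides when an expired lock may be taken over.
                    "update": {},
                },
            )
            return record.status == "ACTIVE" and record.pod_id == self.pod_id
        except Exception as e:
            verbose_proxy_logger.error(f"PodLockManager.acquire_lock error: {e}")
            return False

    async def renew_lock(self) -> bool:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            return False
        try:
            await prisma_client.db.litellm_cronjob.update(
                where={"cronjob_id": self.cronjob_id, "pod_id": self.pod_id},
                data={
                    "ttl": datetime.now(timezone.utc)
                    + timedelta(seconds=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS),
                    "last_updated": datetime.now(timezone.utc),
                },
            )
            return True
        except Exception as e:
            verbose_proxy_logger.error(f"PodLockManager.renew_lock error: {e}")
            return False

    async def release_lock(self) -> bool:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            return False
        try:
            await prisma_client.db.litellm_cronjob.update(
                where={"cronjob_id": self.cronjob_id, "pod_id": self.pod_id},
                data={"status": "INACTIVE"},
            )
            return True
        except Exception as e:
            verbose_proxy_logger.error(f"PodLockManager.release_lock error: {e}")
            return False

The key design point visible in the hunks above is that only acquire_lock's return value is branched on; release_lock runs in a finally block so the leader always gives the lock back even if committing spend updates fails.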
tests/litellm/proxy/db/test_pod_lock_manager.py (new file, 127 additions)
@@ -0,0 +1,127 @@
+import json
+import os
+import sys
+from datetime import datetime, timedelta, timezone
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
+from litellm.proxy.db.pod_lock_manager import PodLockManager
+
+
+# Mock Prisma client class
+class MockPrismaClient:
+    def __init__(self):
+        self.db = MagicMock()
+        self.db.litellm_cronjob = AsyncMock()
+
+
+@pytest.fixture
+def mock_prisma(monkeypatch):
+    mock_client = MockPrismaClient()
+
+    # Mock the prisma_client import in proxy_server
+    def mock_get_prisma():
+        return mock_client
+
+    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_client)
+    return mock_client
+
+
+@pytest.fixture
+def pod_lock_manager():
+    return PodLockManager(cronjob_id="test_job")
+
+
+@pytest.mark.asyncio
+async def test_acquire_lock_success(pod_lock_manager, mock_prisma):
+    # Mock successful lock acquisition
+    mock_response = AsyncMock()
+    mock_response.status = "ACTIVE"
+    mock_response.pod_id = pod_lock_manager.pod_id
+    mock_prisma.db.litellm_cronjob.upsert.return_value = mock_response
+
+    result = await pod_lock_manager.acquire_lock()
+    assert result == True
+
+    # Verify upsert was called with correct parameters
+    mock_prisma.db.litellm_cronjob.upsert.assert_called_once()
+    call_args = mock_prisma.db.litellm_cronjob.upsert.call_args[1]
+    assert call_args["where"]["cronjob_id"] == "test_job"
+    assert "create" in call_args["data"]
+    assert "update" in call_args["data"]
+
+
+@pytest.mark.asyncio
+async def test_acquire_lock_failure(pod_lock_manager, mock_prisma):
+    """
+    Test that the lock is not acquired if the lock is held by a different pod
+    """
+    # Mock failed lock acquisition (different pod holds the lock)
+    mock_response = AsyncMock()
+    mock_response.status = "ACTIVE"
+    mock_response.pod_id = "different_pod_id"
+    mock_prisma.db.litellm_cronjob.upsert.return_value = mock_response
+
+    result = await pod_lock_manager.acquire_lock()
+    assert result == False
+
+
+@pytest.mark.asyncio
+async def test_renew_lock(pod_lock_manager, mock_prisma):
+    # Mock successful lock renewal
+    mock_prisma.db.litellm_cronjob.update.return_value = AsyncMock()
+
+    await pod_lock_manager.renew_lock()
+
+    # Verify update was called with correct parameters
+    mock_prisma.db.litellm_cronjob.update.assert_called_once()
+    call_args = mock_prisma.db.litellm_cronjob.update.call_args[1]
+    assert call_args["where"]["cronjob_id"] == "test_job"
+    assert call_args["where"]["pod_id"] == pod_lock_manager.pod_id
+    assert "ttl" in call_args["data"]
+    assert "last_updated" in call_args["data"]
+
+
+@pytest.mark.asyncio
+async def test_release_lock(pod_lock_manager, mock_prisma):
+    # Mock successful lock release
+    mock_prisma.db.litellm_cronjob.update.return_value = AsyncMock()
+
+    await pod_lock_manager.release_lock()
+
+    # Verify update was called with correct parameters
+    mock_prisma.db.litellm_cronjob.update.assert_called_once()
+    call_args = mock_prisma.db.litellm_cronjob.update.call_args[1]
+    assert call_args["where"]["cronjob_id"] == "test_job"
+    assert call_args["where"]["pod_id"] == pod_lock_manager.pod_id
+    assert call_args["data"]["status"] == "INACTIVE"
+
+
+@pytest.mark.asyncio
+async def test_prisma_client_none(pod_lock_manager, monkeypatch):
+    # Mock prisma_client as None
+    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None)
+
+    # Test all methods with None client
+    assert await pod_lock_manager.acquire_lock() == False
+    assert await pod_lock_manager.renew_lock() == False
+    assert await pod_lock_manager.release_lock() == False
+
+
+@pytest.mark.asyncio
+async def test_database_error_handling(pod_lock_manager, mock_prisma):
+    # Mock database errors
+    mock_prisma.db.litellm_cronjob.upsert.side_effect = Exception("Database error")
+    mock_prisma.db.litellm_cronjob.update.side_effect = Exception("Database error")
+
+    # Test error handling in all methods
+    assert await pod_lock_manager.acquire_lock() == False
+    await pod_lock_manager.renew_lock()  # Should not raise exception
+    await pod_lock_manager.release_lock()  # Should not raise exception
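As a usage note: these are standard pytest-asyncio tests (each coroutine is marked with @pytest.mark.asyncio), so with a development checkout of litellm and the pytest plus pytest-asyncio packages installed they would typically be run with something like:

    pytest tests/litellm/proxy/db/test_pod_lock_manager.py -v

Taken together, the tests pin down the contract the writer relies on: acquire_lock returns True only when the upserted litellm_cronjob row is ACTIVE and owned by this pod's pod_id, renew_lock and release_lock scope their update to both cronjob_id and pod_id, and every method degrades safely (returning False or swallowing the error) when prisma_client is None or the database call fails.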