mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
rename pod lock manager
This commit is contained in:
parent
0edd4aa8a7
commit
c53d172b06
3 changed files with 131 additions and 4 deletions
|
@ -22,7 +22,7 @@ from litellm.proxy._types import (
|
||||||
LiteLLM_UserTable,
|
LiteLLM_UserTable,
|
||||||
SpendLogsPayload,
|
SpendLogsPayload,
|
||||||
)
|
)
|
||||||
from litellm.proxy.db.pod_leader_manager import PodLockManager
|
from litellm.proxy.db.pod_lock_manager import PodLockManager
|
||||||
from litellm.proxy.db.redis_update_buffer import (
|
from litellm.proxy.db.redis_update_buffer import (
|
||||||
DBSpendUpdateTransactions,
|
DBSpendUpdateTransactions,
|
||||||
RedisUpdateBuffer,
|
RedisUpdateBuffer,
|
||||||
|
@ -49,7 +49,7 @@ class DBSpendUpdateWriter:
|
||||||
):
|
):
|
||||||
self.redis_cache = redis_cache
|
self.redis_cache = redis_cache
|
||||||
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
|
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
|
||||||
self.pod_leader_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
|
self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_database(
|
async def update_database(
|
||||||
|
@ -405,7 +405,7 @@ class DBSpendUpdateWriter:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only commit from redis to db if this pod is the leader
|
# Only commit from redis to db if this pod is the leader
|
||||||
if await self.pod_leader_manager.acquire_lock():
|
if await self.pod_lock_manager.acquire_lock():
|
||||||
verbose_proxy_logger.debug("acquired lock for spend updates")
|
verbose_proxy_logger.debug("acquired lock for spend updates")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -422,7 +422,7 @@ class DBSpendUpdateWriter:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.error(f"Error committing spend updates: {e}")
|
verbose_proxy_logger.error(f"Error committing spend updates: {e}")
|
||||||
finally:
|
finally:
|
||||||
await self.pod_leader_manager.release_lock()
|
await self.pod_lock_manager.release_lock()
|
||||||
else:
|
else:
|
||||||
db_spend_update_transactions = DBSpendUpdateTransactions(
|
db_spend_update_transactions = DBSpendUpdateTransactions(
|
||||||
user_list_transactions=prisma_client.user_list_transactions,
|
user_list_transactions=prisma_client.user_list_transactions,
|
||||||
|
|
127
tests/litellm/proxy/db/test_pod_lock_manager.py
Normal file
127
tests/litellm/proxy/db/test_pod_lock_manager.py
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
|
||||||
|
from litellm.proxy.db.pod_lock_manager import PodLockManager
|
||||||
|
|
||||||
|
|
||||||
|
# Mock Prisma client class
|
||||||
|
class MockPrismaClient:
|
||||||
|
def __init__(self):
|
||||||
|
self.db = MagicMock()
|
||||||
|
self.db.litellm_cronjob = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_prisma(monkeypatch):
|
||||||
|
mock_client = MockPrismaClient()
|
||||||
|
|
||||||
|
# Mock the prisma_client import in proxy_server
|
||||||
|
def mock_get_prisma():
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_client)
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pod_lock_manager():
|
||||||
|
return PodLockManager(cronjob_id="test_job")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_acquire_lock_success(pod_lock_manager, mock_prisma):
|
||||||
|
# Mock successful lock acquisition
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = "ACTIVE"
|
||||||
|
mock_response.pod_id = pod_lock_manager.pod_id
|
||||||
|
mock_prisma.db.litellm_cronjob.upsert.return_value = mock_response
|
||||||
|
|
||||||
|
result = await pod_lock_manager.acquire_lock()
|
||||||
|
assert result == True
|
||||||
|
|
||||||
|
# Verify upsert was called with correct parameters
|
||||||
|
mock_prisma.db.litellm_cronjob.upsert.assert_called_once()
|
||||||
|
call_args = mock_prisma.db.litellm_cronjob.upsert.call_args[1]
|
||||||
|
assert call_args["where"]["cronjob_id"] == "test_job"
|
||||||
|
assert "create" in call_args["data"]
|
||||||
|
assert "update" in call_args["data"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_acquire_lock_failure(pod_lock_manager, mock_prisma):
|
||||||
|
"""
|
||||||
|
Test that the lock is not acquired if the lock is held by a different pod
|
||||||
|
"""
|
||||||
|
# Mock failed lock acquisition (different pod holds the lock)
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status = "ACTIVE"
|
||||||
|
mock_response.pod_id = "different_pod_id"
|
||||||
|
mock_prisma.db.litellm_cronjob.upsert.return_value = mock_response
|
||||||
|
|
||||||
|
result = await pod_lock_manager.acquire_lock()
|
||||||
|
assert result == False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_renew_lock(pod_lock_manager, mock_prisma):
|
||||||
|
# Mock successful lock renewal
|
||||||
|
mock_prisma.db.litellm_cronjob.update.return_value = AsyncMock()
|
||||||
|
|
||||||
|
await pod_lock_manager.renew_lock()
|
||||||
|
|
||||||
|
# Verify update was called with correct parameters
|
||||||
|
mock_prisma.db.litellm_cronjob.update.assert_called_once()
|
||||||
|
call_args = mock_prisma.db.litellm_cronjob.update.call_args[1]
|
||||||
|
assert call_args["where"]["cronjob_id"] == "test_job"
|
||||||
|
assert call_args["where"]["pod_id"] == pod_lock_manager.pod_id
|
||||||
|
assert "ttl" in call_args["data"]
|
||||||
|
assert "last_updated" in call_args["data"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_release_lock(pod_lock_manager, mock_prisma):
|
||||||
|
# Mock successful lock release
|
||||||
|
mock_prisma.db.litellm_cronjob.update.return_value = AsyncMock()
|
||||||
|
|
||||||
|
await pod_lock_manager.release_lock()
|
||||||
|
|
||||||
|
# Verify update was called with correct parameters
|
||||||
|
mock_prisma.db.litellm_cronjob.update.assert_called_once()
|
||||||
|
call_args = mock_prisma.db.litellm_cronjob.update.call_args[1]
|
||||||
|
assert call_args["where"]["cronjob_id"] == "test_job"
|
||||||
|
assert call_args["where"]["pod_id"] == pod_lock_manager.pod_id
|
||||||
|
assert call_args["data"]["status"] == "INACTIVE"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prisma_client_none(pod_lock_manager, monkeypatch):
|
||||||
|
# Mock prisma_client as None
|
||||||
|
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None)
|
||||||
|
|
||||||
|
# Test all methods with None client
|
||||||
|
assert await pod_lock_manager.acquire_lock() == False
|
||||||
|
assert await pod_lock_manager.renew_lock() == False
|
||||||
|
assert await pod_lock_manager.release_lock() == False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_database_error_handling(pod_lock_manager, mock_prisma):
|
||||||
|
# Mock database errors
|
||||||
|
mock_prisma.db.litellm_cronjob.upsert.side_effect = Exception("Database error")
|
||||||
|
mock_prisma.db.litellm_cronjob.update.side_effect = Exception("Database error")
|
||||||
|
|
||||||
|
# Test error handling in all methods
|
||||||
|
assert await pod_lock_manager.acquire_lock() == False
|
||||||
|
await pod_lock_manager.renew_lock() # Should not raise exception
|
||||||
|
await pod_lock_manager.release_lock() # Should not raise exception
|
Loading…
Add table
Add a link
Reference in a new issue