[Feat] Emit Key, Team Budget metrics on a cron job schedule (#9528)

* _initialize_remaining_budget_metrics

* initialize_budget_metrics_cron_job

* initialize_budget_metrics_cron_job

* initialize_budget_metrics_cron_job

* test_initialize_budget_metrics_cron_job

* LITELLM_PROXY_ADMIN_NAME

* fix code qa checks

* test_initialize_budget_metrics_cron_job

* test_initialize_budget_metrics_cron_job

* pod lock manager allow dynamic cron job ID

* fix pod lock manager

* require cronjobid for PodLockManager

* fix DB_SPEND_UPDATE_JOB_NAME acquire / release lock

* add comment on prometheus logger

* add debug statements for emitting key, team budget metrics

* test_pod_lock_manager.py

* test_initialize_budget_metrics_cron_job

* initialize_budget_metrics_cron_job

* initialize_remaining_budget_metrics

* remove outdated test
Ishaan Jaff 2025-04-10 16:59:14 -07:00 committed by GitHub
parent 557a2ca102
commit 4c85e13226
10 changed files with 346 additions and 142 deletions
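At a high level, this PR replaces the old fire-and-forget startup task with a scheduled job that emits key and team budget metrics every few minutes, using a Redis-backed pod lock so that only one pod emits per interval. Below is a minimal sketch of that acquire-emit-release pattern; `FakeLock` and `emit_budget_metrics_once` are illustrative stand-ins, not names from this PR.

import asyncio

class FakeLock:
    # Stand-in for the PR's PodLockManager; pretend this pod wins the Redis SET NX race.
    async def acquire_lock(self, cronjob_id: str) -> bool:
        return True

    async def release_lock(self, cronjob_id: str) -> None:
        pass

async def emit_budget_metrics_once(lock: FakeLock, job_name: str) -> None:
    # Only the pod that acquires the lock emits; every other pod skips this run.
    if await lock.acquire_lock(cronjob_id=job_name):
        try:
            print("emitting key/team budget metrics")
        finally:
            await lock.release_lock(cronjob_id=job_name)

asyncio.run(emit_budget_metrics_once(FakeLock(), "prometheus_emit_budget_metrics_job"))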

View file

@@ -480,6 +480,7 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv
 ########################### Logging Callback Constants ###########################
 AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
+PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5
 MCP_TOOL_NAME_PREFIX = "mcp_tool"
 ########################### LiteLLM Proxy Specific Constants ###########################
@@ -514,6 +515,7 @@ LITELLM_PROXY_ADMIN_NAME = "default_user_id"
 ########################### DB CRON JOB NAMES ###########################
 DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
+PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME = "prometheus_emit_budget_metrics_job"
 DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60  # 1 minute
 PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
 PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
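Two timing knobs interact here: the cron-job lock TTL (60 seconds) and the budget-metrics refresh interval (5 minutes). Because the interval is longer than the TTL, a pod that dies while holding the lock cannot block the next scheduled run. A quick sanity check using the values from this diff:

# Values as defined in litellm's constants in this diff.
PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60  # 1 minute

# The lock expires well before the next scheduled emission, so a crashed
# leader cannot deadlock future runs.
assert DEFAULT_CRON_JOB_LOCK_TTL_SECONDS < PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES * 60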

View file

@@ -1,10 +1,19 @@
 # used for /metrics endpoint on LiteLLM Proxy
 #### What this does ####
 #    On success, log events to Prometheus
-import asyncio
 import sys
 from datetime import datetime, timedelta
-from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    cast,
+)

 import litellm
 from litellm._logging import print_verbose, verbose_logger
@@ -14,6 +23,11 @@ from litellm.types.integrations.prometheus import *
 from litellm.types.utils import StandardLoggingPayload
 from litellm.utils import get_end_user_id_for_cost_tracking

+if TYPE_CHECKING:
+    from apscheduler.schedulers.asyncio import AsyncIOScheduler
+else:
+    AsyncIOScheduler = Any

 class PrometheusLogger(CustomLogger):
     # Class variables or attributes
@@ -359,8 +373,6 @@ class PrometheusLogger(CustomLogger):
                     label_name="litellm_requests_metric"
                 ),
             )
-            self._initialize_prometheus_startup_metrics()
         except Exception as e:
             print_verbose(f"Got exception on init prometheus client {str(e)}")
             raise e
@@ -988,9 +1000,9 @@ class PrometheusLogger(CustomLogger):
     ):
         try:
             verbose_logger.debug("setting remaining tokens requests metric")
-            standard_logging_payload: Optional[
-                StandardLoggingPayload
-            ] = request_kwargs.get("standard_logging_object")
+            standard_logging_payload: Optional[StandardLoggingPayload] = (
+                request_kwargs.get("standard_logging_object")
+            )

             if standard_logging_payload is None:
                 return
@@ -1337,24 +1349,6 @@
         return max_budget - spend

-    def _initialize_prometheus_startup_metrics(self):
-        """
-        Initialize prometheus startup metrics
-
-        Helper to create tasks for initializing metrics that are required on startup - eg. remaining budget metrics
-        """
-        if litellm.prometheus_initialize_budget_metrics is not True:
-            verbose_logger.debug("Prometheus: skipping budget metrics initialization")
-            return
-
-        try:
-            if asyncio.get_running_loop():
-                asyncio.create_task(self._initialize_remaining_budget_metrics())
-        except RuntimeError as e:  # no running event loop
-            verbose_logger.exception(
-                f"No running event loop - skipping budget metrics initialization: {str(e)}"
-            )

     async def _initialize_budget_metrics(
         self,
         data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]],
@@ -1475,12 +1469,41 @@
             data_type="keys",
         )

-    async def _initialize_remaining_budget_metrics(self):
+    async def initialize_remaining_budget_metrics(self):
         """
-        Initialize remaining budget metrics for all teams to avoid metric discrepancies.
+        Handler for initializing remaining budget metrics for all teams to avoid metric discrepancies.

         Runs when prometheus logger starts up.
+
+        - If redis cache is available, we use the pod lock manager to acquire a lock and initialize the metrics.
+            - Ensures only one pod emits the metrics at a time.
+        - If redis cache is not available, we initialize the metrics directly.
         """
+        from litellm.constants import PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
+        from litellm.proxy.proxy_server import proxy_logging_obj
+
+        pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager
+
+        # if using redis, ensure only one pod emits the metrics at a time
+        if pod_lock_manager and pod_lock_manager.redis_cache:
+            if await pod_lock_manager.acquire_lock(
+                cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
+            ):
+                try:
+                    await self._initialize_remaining_budget_metrics()
+                finally:
+                    await pod_lock_manager.release_lock(
+                        cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
+                    )
+        else:
+            # if not using redis, initialize the metrics directly
+            await self._initialize_remaining_budget_metrics()
+
+    async def _initialize_remaining_budget_metrics(self):
+        """
+        Helper to initialize remaining budget metrics for all teams and API keys.
+        """
+        verbose_logger.debug("Emitting key, team budget metrics....")
         await self._initialize_team_budget_metrics()
         await self._initialize_api_key_budget_metrics()
@@ -1737,6 +1760,36 @@
             return (end_time - start_time).total_seconds()
         return None

+    @staticmethod
+    def initialize_budget_metrics_cron_job(scheduler: AsyncIOScheduler):
+        """
+        Initialize budget metrics as a cron job. This job runs every `PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES` minutes.
+
+        It emits the current remaining budget metrics for all Keys and Teams.
+        """
+        from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+        from litellm.integrations.custom_logger import CustomLogger
+        from litellm.integrations.prometheus import PrometheusLogger
+
+        prometheus_loggers: List[CustomLogger] = (
+            litellm.logging_callback_manager.get_custom_loggers_for_type(
+                callback_type=PrometheusLogger
+            )
+        )
+        # we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them
+        verbose_logger.debug("found %s prometheus loggers", len(prometheus_loggers))
+        if len(prometheus_loggers) > 0:
+            prometheus_logger = cast(PrometheusLogger, prometheus_loggers[0])
+            verbose_logger.debug(
+                "Initializing remaining budget metrics as a cron job executing every %s minutes"
+                % PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+            )
+            scheduler.add_job(
+                prometheus_logger.initialize_remaining_budget_metrics,
+                "interval",
+                minutes=PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES,
+            )

     @staticmethod
     def _mount_metrics_endpoint(premium_user: bool):
         """

View file

@@ -53,7 +53,7 @@ class DBSpendUpdateWriter:
     ):
         self.redis_cache = redis_cache
         self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
-        self.pod_lock_manager = PodLockManager(cronjob_id=DB_SPEND_UPDATE_JOB_NAME)
+        self.pod_lock_manager = PodLockManager()
         self.spend_update_queue = SpendUpdateQueue()
         self.daily_spend_update_queue = DailySpendUpdateQueue()
@@ -383,7 +383,9 @@
             )
             # Only commit from redis to db if this pod is the leader
-            if await self.pod_lock_manager.acquire_lock():
+            if await self.pod_lock_manager.acquire_lock(
+                cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
+            ):
                 verbose_proxy_logger.debug("acquired lock for spend updates")
                 try:
@@ -411,7 +413,9 @@
                 except Exception as e:
                     verbose_proxy_logger.error(f"Error committing spend updates: {e}")
                 finally:
-                    await self.pod_lock_manager.release_lock()
+                    await self.pod_lock_manager.release_lock(
+                        cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
+                    )

     async def _commit_spend_updates_to_db_without_redis_buffer(
         self,

View file

@@ -21,18 +21,18 @@ class PodLockManager:
     Ensures that only one pod can run a cron job at a time.
     """

-    def __init__(self, cronjob_id: str, redis_cache: Optional[RedisCache] = None):
+    def __init__(self, redis_cache: Optional[RedisCache] = None):
         self.pod_id = str(uuid.uuid4())
-        self.cronjob_id = cronjob_id
         self.redis_cache = redis_cache
-        # Define a unique key for this cronjob lock in Redis.
-        self.lock_key = PodLockManager.get_redis_lock_key(cronjob_id)

     @staticmethod
     def get_redis_lock_key(cronjob_id: str) -> str:
         return f"cronjob_lock:{cronjob_id}"

-    async def acquire_lock(self) -> Optional[bool]:
+    async def acquire_lock(
+        self,
+        cronjob_id: str,
+    ) -> Optional[bool]:
         """
         Attempt to acquire the lock for a specific cron job using Redis.
         Uses the SET command with NX and EX options to ensure atomicity.
@@ -44,12 +44,13 @@
             verbose_proxy_logger.debug(
                 "Pod %s attempting to acquire Redis lock for cronjob_id=%s",
                 self.pod_id,
-                self.cronjob_id,
+                cronjob_id,
             )
             # Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX)
             # and with an expiration (EX) to avoid deadlocks.
+            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
             acquired = await self.redis_cache.async_set_cache(
-                self.lock_key,
+                lock_key,
                 self.pod_id,
                 nx=True,
                 ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
@@ -58,13 +59,13 @@
                 verbose_proxy_logger.info(
                     "Pod %s successfully acquired Redis lock for cronjob_id=%s",
                     self.pod_id,
-                    self.cronjob_id,
+                    cronjob_id,
                 )
                 return True
             else:
                 # Check if the current pod already holds the lock
-                current_value = await self.redis_cache.async_get_cache(self.lock_key)
+                current_value = await self.redis_cache.async_get_cache(lock_key)
                 if current_value is not None:
                     if isinstance(current_value, bytes):
                         current_value = current_value.decode("utf-8")
@@ -72,18 +73,21 @@
                         verbose_proxy_logger.info(
                             "Pod %s already holds the Redis lock for cronjob_id=%s",
                             self.pod_id,
-                            self.cronjob_id,
+                            cronjob_id,
                         )
-                        self._emit_acquired_lock_event(self.cronjob_id, self.pod_id)
+                        self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                         return True
             return False
         except Exception as e:
             verbose_proxy_logger.error(
-                f"Error acquiring Redis lock for {self.cronjob_id}: {e}"
+                f"Error acquiring Redis lock for {cronjob_id}: {e}"
             )
             return False

-    async def release_lock(self):
+    async def release_lock(
+        self,
+        cronjob_id: str,
+    ):
         """
         Release the lock if the current pod holds it.
         Uses get and delete commands to ensure that only the owner can release the lock.
@@ -92,46 +96,52 @@
             verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
             return
         try:
+            cronjob_id = cronjob_id
             verbose_proxy_logger.debug(
                 "Pod %s attempting to release Redis lock for cronjob_id=%s",
                 self.pod_id,
-                self.cronjob_id,
+                cronjob_id,
             )
-            current_value = await self.redis_cache.async_get_cache(self.lock_key)
+            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
+            current_value = await self.redis_cache.async_get_cache(lock_key)
             if current_value is not None:
                 if isinstance(current_value, bytes):
                     current_value = current_value.decode("utf-8")
                 if current_value == self.pod_id:
-                    result = await self.redis_cache.async_delete_cache(self.lock_key)
+                    result = await self.redis_cache.async_delete_cache(lock_key)
                     if result == 1:
                         verbose_proxy_logger.info(
                             "Pod %s successfully released Redis lock for cronjob_id=%s",
                             self.pod_id,
-                            self.cronjob_id,
+                            cronjob_id,
                         )
-                        self._emit_released_lock_event(self.cronjob_id, self.pod_id)
+                        self._emit_released_lock_event(
+                            cronjob_id=cronjob_id,
+                            pod_id=self.pod_id,
+                        )
                     else:
                         verbose_proxy_logger.debug(
                             "Pod %s failed to release Redis lock for cronjob_id=%s",
                             self.pod_id,
-                            self.cronjob_id,
+                            cronjob_id,
                         )
                 else:
                     verbose_proxy_logger.debug(
                         "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
                         self.pod_id,
-                        self.cronjob_id,
+                        cronjob_id,
                         current_value,
                     )
             else:
                 verbose_proxy_logger.debug(
                     "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
                     self.pod_id,
-                    self.cronjob_id,
+                    cronjob_id,
                 )
         except Exception as e:
             verbose_proxy_logger.error(
-                f"Error releasing Redis lock for {self.cronjob_id}: {e}"
+                f"Error releasing Redis lock for {cronjob_id}: {e}"
             )

     @staticmethod
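With `cronjob_id` moved from the constructor into each call, a single `PodLockManager` can now guard any number of cron jobs. A minimal usage sketch under that assumption; `run_job_once` is an illustrative helper, and `lock_manager` is a `PodLockManager` built with a configured `RedisCache`.

async def run_job_once(lock_manager, cronjob_id: str) -> None:
    # The job name now travels with each call instead of living on the instance.
    if await lock_manager.acquire_lock(cronjob_id=cronjob_id):
        try:
            ...  # do the work only the lock holder should perform
        finally:
            # Always release, even if the job body raises.
            await lock_manager.release_lock(cronjob_id=cronjob_id)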

View file

@@ -10,6 +10,36 @@ model_list:
       api_key: fake-key

 litellm_settings:
+  prometheus_initialize_budget_metrics: true
+  callbacks: ["prometheus"]
+
+  mcp_tools:
+    - name: "get_current_time"
+      description: "Get the current time"
+      input_schema: {
+        "type": "object",
+        "properties": {
+          "format": {
+            "type": "string",
+            "description": "The format of the time to return",
+            "enum": ["short"]
+          }
+        }
+      }
+      handler: "mcp_tools.get_current_time"
+    - name: "get_current_date"
+      description: "Get the current date"
+      input_schema: {
+        "type": "object",
+        "properties": {
+          "format": {
+            "type": "string",
+            "description": "The format of the date to return",
+            "enum": ["short"]
+          }
+        }
+      }
+      handler: "mcp_tools.get_current_date"
   default_team_settings:
     - team_id: test_dev
       success_callback: ["langfuse", "s3"]
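For reference, `prometheus_initialize_budget_metrics: true` above corresponds to the module-level flag the proxy checks at startup; a minimal Python equivalent, assuming the proxy maps `litellm_settings` keys onto attributes of the `litellm` module as usual:

import litellm

# Equivalent of the yaml settings above.
litellm.prometheus_initialize_budget_metrics = True
litellm.callbacks = ["prometheus"]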

View file

@@ -803,9 +803,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
     dual_cache=user_api_key_cache
 )
 litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[
-    RedisCache
-] = None  # redis cache used for tracking spend, tpm/rpm limits
+redis_usage_cache: Optional[RedisCache] = (
+    None  # redis cache used for tracking spend, tpm/rpm limits
+)
 user_custom_auth = None
 user_custom_key_generate = None
 user_custom_sso = None
@@ -1131,9 +1131,9 @@ async def update_cache(  # noqa: PLR0915
             _id = "team_id:{}".format(team_id)
             try:
                 # Fetch the existing cost for the given user
-                existing_spend_obj: Optional[
-                    LiteLLM_TeamTable
-                ] = await user_api_key_cache.async_get_cache(key=_id)
+                existing_spend_obj: Optional[LiteLLM_TeamTable] = (
+                    await user_api_key_cache.async_get_cache(key=_id)
+                )
                 if existing_spend_obj is None:
                     # do nothing if team not in api key cache
                     return
@@ -2812,9 +2812,9 @@ async def initialize(  # noqa: PLR0915
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
-        os.environ[
-            "AZURE_API_VERSION"
-        ] = api_version  # set this for azure - litellm can read this from the env
+        os.environ["AZURE_API_VERSION"] = (
+            api_version  # set this for azure - litellm can read this from the env
+        )
     if max_tokens:  # model-specific param
         dynamic_config[user_model]["max_tokens"] = max_tokens
     if temperature:  # model-specific param
@@ -3191,6 +3191,11 @@ class ProxyStartupEvent:
             )
             await proxy_logging_obj.slack_alerting_instance.send_fallback_stats_from_prometheus()

+        if litellm.prometheus_initialize_budget_metrics is True:
+            from litellm.integrations.prometheus import PrometheusLogger
+
+            PrometheusLogger.initialize_budget_metrics_cron_job(scheduler=scheduler)
+
         scheduler.start()

     @classmethod
@@ -7753,9 +7758,9 @@ async def get_config_list(
                     hasattr(sub_field_info, "description")
                     and sub_field_info.description is not None
                 ):
-                    nested_fields[
-                        idx
-                    ].field_description = sub_field_info.description
+                    nested_fields[idx].field_description = (
+                        sub_field_info.description
+                    )
                 idx += 1

         _stored_in_db = None

View file

@@ -0,0 +1,44 @@
+"""
+Mock prometheus unit tests, these don't rely on LLM API calls
+"""
+
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from apscheduler.schedulers.asyncio import AsyncIOScheduler
+
+import litellm
+from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+from litellm.integrations.prometheus import PrometheusLogger
+
+
+def test_initialize_budget_metrics_cron_job():
+    # Create a scheduler
+    scheduler = AsyncIOScheduler()
+
+    # Create and register a PrometheusLogger
+    prometheus_logger = PrometheusLogger()
+    litellm.callbacks = [prometheus_logger]
+
+    # Initialize the cron job
+    PrometheusLogger.initialize_budget_metrics_cron_job(scheduler)
+
+    # Verify that a job was added to the scheduler
+    jobs = scheduler.get_jobs()
+    assert len(jobs) == 1
+
+    # Verify job properties
+    job = jobs[0]
+    assert (
+        job.trigger.interval.total_seconds() / 60
+        == PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
+    )
+    assert job.func.__name__ == "initialize_remaining_budget_metrics"

View file

@@ -29,7 +29,7 @@ def mock_redis():
 @pytest.fixture
 def pod_lock_manager(mock_redis):
-    return PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+    return PodLockManager(redis_cache=mock_redis)

 @pytest.mark.asyncio
@@ -40,12 +40,15 @@ async def test_acquire_lock_success(pod_lock_manager, mock_redis):
     # Mock successful acquisition (SET NX returns True)
     mock_redis.async_set_cache.return_value = True

-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == True

     # Verify set_cache was called with correct parameters
+    lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
     mock_redis.async_set_cache.assert_called_once_with(
-        pod_lock_manager.lock_key,
+        lock_key,
         pod_lock_manager.pod_id,
         nx=True,
         ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
@@ -62,13 +65,16 @@ async def test_acquire_lock_existing_active(pod_lock_manager, mock_redis):
     # Mock get_cache to return a different pod's ID
     mock_redis.async_get_cache.return_value = "different_pod_id"

-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == False

     # Verify set_cache was called
     mock_redis.async_set_cache.assert_called_once()
     # Verify get_cache was called to check existing lock
-    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+    mock_redis.async_get_cache.assert_called_once_with(lock_key)

 @pytest.mark.asyncio
@@ -89,7 +95,9 @@ async def test_acquire_lock_expired(pod_lock_manager, mock_redis):
     # Then set succeeds on retry (simulating key expiring between checks)
     mock_redis.async_set_cache.side_effect = [False, True]

-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == False  # First attempt fails

     # Reset mock for a second attempt
@@ -97,7 +105,9 @@
     mock_redis.async_set_cache.return_value = True

     # Try again (simulating the lock expired)
-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == True

     # Verify set_cache was called again
@@ -114,12 +124,15 @@ async def test_release_lock_success(pod_lock_manager, mock_redis):
     # Mock successful deletion
     mock_redis.async_delete_cache.return_value = 1

-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )

     # Verify get_cache was called
-    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+    mock_redis.async_get_cache.assert_called_once_with(lock_key)
     # Verify delete_cache was called
-    mock_redis.async_delete_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    mock_redis.async_delete_cache.assert_called_once_with(lock_key)

 @pytest.mark.asyncio
@@ -130,10 +143,13 @@ async def test_release_lock_different_pod(pod_lock_manager, mock_redis):
     # Mock get_cache to return a different pod's ID
     mock_redis.async_get_cache.return_value = "different_pod_id"

-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )

     # Verify get_cache was called
-    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+    mock_redis.async_get_cache.assert_called_once_with(lock_key)
     # Verify delete_cache was NOT called
     mock_redis.async_delete_cache.assert_not_called()
@@ -146,10 +162,13 @@ async def test_release_lock_no_lock(pod_lock_manager, mock_redis):
     # Mock get_cache to return None (no lock)
     mock_redis.async_get_cache.return_value = None

-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )

     # Verify get_cache was called
-    mock_redis.async_get_cache.assert_called_once_with(pod_lock_manager.lock_key)
+    lock_key = pod_lock_manager.get_redis_lock_key(cronjob_id="test_job")
+    mock_redis.async_get_cache.assert_called_once_with(lock_key)
     # Verify delete_cache was NOT called
     mock_redis.async_delete_cache.assert_not_called()
@@ -159,13 +178,20 @@ async def test_redis_none(monkeypatch):
     """
     Test behavior when redis_cache is None
     """
-    pod_lock_manager = PodLockManager(cronjob_id="test_job", redis_cache=None)
+    pod_lock_manager = PodLockManager(redis_cache=None)

     # Test acquire_lock with None redis_cache
-    assert await pod_lock_manager.acquire_lock() is None
+    assert (
+        await pod_lock_manager.acquire_lock(
+            cronjob_id="test_job",
+        )
+        is None
+    )

     # Test release_lock with None redis_cache (should not raise exception)
-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )

 @pytest.mark.asyncio
@@ -179,7 +205,9 @@ async def test_redis_error_handling(pod_lock_manager, mock_redis):
     mock_redis.async_delete_cache.side_effect = Exception("Redis error")

     # Test acquire_lock error handling
-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == False

     # Reset side effect for get_cache for the release test
@@ -187,7 +215,9 @@
     mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id

     # Test release_lock error handling (should not raise exception)
-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )

 @pytest.mark.asyncio
@@ -200,14 +230,18 @@ async def test_bytes_handling(pod_lock_manager, mock_redis):
     # Mock get_cache to return bytes
     mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id.encode("utf-8")

-    result = await pod_lock_manager.acquire_lock()
+    result = await pod_lock_manager.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result == True

     # Reset for release test
     mock_redis.async_get_cache.return_value = pod_lock_manager.pod_id.encode("utf-8")
     mock_redis.async_delete_cache.return_value = 1

-    await pod_lock_manager.release_lock()
+    await pod_lock_manager.release_lock(
+        cronjob_id="test_job",
+    )
     mock_redis.async_delete_cache.assert_called_once()
@@ -217,15 +251,17 @@ async def test_concurrent_lock_acquisition_simulation():
     Simulate multiple pods trying to acquire the lock simultaneously
     """
     mock_redis = MockRedisCache()
-    pod1 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
-    pod2 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
-    pod3 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+    pod1 = PodLockManager(redis_cache=mock_redis)
+    pod2 = PodLockManager(redis_cache=mock_redis)
+    pod3 = PodLockManager(redis_cache=mock_redis)

     # Simulate first pod getting the lock
     mock_redis.async_set_cache.return_value = True

     # First pod should get the lock
-    result1 = await pod1.acquire_lock()
+    result1 = await pod1.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result1 == True

     # Simulate other pods failing to get the lock
@@ -233,8 +269,12 @@
     mock_redis.async_get_cache.return_value = pod1.pod_id

     # Other pods should fail to acquire
-    result2 = await pod2.acquire_lock()
-    result3 = await pod3.acquire_lock()
+    result2 = await pod2.acquire_lock(
+        cronjob_id="test_job",
+    )
+    result3 = await pod3.acquire_lock(
+        cronjob_id="test_job",
+    )

     # Since other pods don't have the lock, they should get False
     assert result2 == False
@@ -246,14 +286,16 @@ async def test_lock_takeover_race_condition(mock_redis):
     """
     Test scenario where multiple pods try to take over an expired lock using Redis
     """
-    pod1 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
-    pod2 = PodLockManager(cronjob_id="test_job", redis_cache=mock_redis)
+    pod1 = PodLockManager(redis_cache=mock_redis)
+    pod2 = PodLockManager(redis_cache=mock_redis)

     # Simulate first pod's acquisition succeeding
     mock_redis.async_set_cache.return_value = True

     # First pod should successfully acquire
-    result1 = await pod1.acquire_lock()
+    result1 = await pod1.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result1 == True

     # Simulate race condition: second pod tries but fails
@@ -261,5 +303,7 @@
     mock_redis.async_get_cache.return_value = pod1.pod_id

     # Second pod should fail to acquire
-    result2 = await pod2.acquire_lock()
+    result2 = await pod2.acquire_lock(
+        cronjob_id="test_job",
+    )
     assert result2 == False

View file

@@ -39,7 +39,7 @@ import time
 @pytest.fixture
-def prometheus_logger():
+def prometheus_logger() -> PrometheusLogger:
     collectors = list(REGISTRY._collector_to_names.keys())
     for collector in collectors:
         REGISTRY.unregister(collector)
@@ -1212,24 +1212,6 @@
     prometheus_logger.litellm_remaining_api_key_budget_metric.assert_not_called()

-def test_initialize_prometheus_startup_metrics_no_loop(prometheus_logger):
-    """
-    Test that _initialize_prometheus_startup_metrics handles case when no event loop exists
-    """
-    # Mock asyncio.get_running_loop to raise RuntimeError
-    litellm.prometheus_initialize_budget_metrics = True
-    with patch(
-        "asyncio.get_running_loop", side_effect=RuntimeError("No running event loop")
-    ), patch("litellm._logging.verbose_logger.exception") as mock_logger:
-        # Call the function
-        prometheus_logger._initialize_prometheus_startup_metrics()
-
-        # Verify the error was logged
-        mock_logger.assert_called_once()
-        assert "No running event loop" in mock_logger.call_args[0][0]

 @pytest.mark.asyncio(scope="session")
 async def test_initialize_api_key_budget_metrics(prometheus_logger):
     """

View file

@@ -141,10 +141,12 @@ async def setup_db_connection(prisma_client):
 async def test_pod_lock_acquisition_when_no_active_lock():
     """Test if a pod can acquire a lock when no lock is active"""
     cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager = PodLockManager(redis_cache=global_redis_cache)

     # Attempt to acquire lock
-    result = await lock_manager.acquire_lock()
+    result = await lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )

     assert result == True, "Pod should be able to acquire lock when no lock exists"
@@ -161,13 +163,19 @@
     """Test if a new pod can acquire lock after previous pod completes"""
     cronjob_id = str(uuid.uuid4())
     # First pod acquires and releases lock
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    await first_lock_manager.acquire_lock()
-    await first_lock_manager.release_lock()
+    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    await first_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )
+    await first_lock_manager.release_lock(
+        cronjob_id=cronjob_id,
+    )

     # Second pod attempts to acquire lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    result = await second_lock_manager.acquire_lock()
+    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    result = await second_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )

     assert result == True, "Second pod should acquire lock after first pod releases it"
@@ -182,15 +190,21 @@
     """Test if a new pod can acquire lock after previous pod's lock expires"""
     cronjob_id = str(uuid.uuid4())
     # First pod acquires lock
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    await first_lock_manager.acquire_lock()
+    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    await first_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )

     # release the lock from the first pod
-    await first_lock_manager.release_lock()
+    await first_lock_manager.release_lock(
+        cronjob_id=cronjob_id,
+    )

     # Second pod attempts to acquire lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    result = await second_lock_manager.acquire_lock()
+    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    result = await second_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )

     assert (
         result == True
@@ -206,11 +220,15 @@
 async def test_pod_lock_release():
     """Test if a pod can successfully release its lock"""
     cronjob_id = str(uuid.uuid4())
-    lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager = PodLockManager(redis_cache=global_redis_cache)

     # Acquire and then release lock
-    await lock_manager.acquire_lock()
-    await lock_manager.release_lock()
+    await lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )
+    await lock_manager.release_lock(
+        cronjob_id=cronjob_id,
+    )

     # Verify in redis
     lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
@@ -224,15 +242,21 @@ async def test_concurrent_lock_acquisition():
     cronjob_id = str(uuid.uuid4())

     # Create multiple lock managers simulating different pods
-    lock_manager1 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    lock_manager2 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    lock_manager3 = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    lock_manager1 = PodLockManager(redis_cache=global_redis_cache)
+    lock_manager2 = PodLockManager(redis_cache=global_redis_cache)
+    lock_manager3 = PodLockManager(redis_cache=global_redis_cache)

     # Try to acquire locks concurrently
     results = await asyncio.gather(
-        lock_manager1.acquire_lock(),
-        lock_manager2.acquire_lock(),
-        lock_manager3.acquire_lock(),
+        lock_manager1.acquire_lock(
+            cronjob_id=cronjob_id,
+        ),
+        lock_manager2.acquire_lock(
+            cronjob_id=cronjob_id,
+        ),
+        lock_manager3.acquire_lock(
+            cronjob_id=cronjob_id,
+        ),
     )

     # Only one should succeed
@@ -254,7 +278,7 @@
 async def test_lock_acquisition_with_expired_ttl():
     """Test that a pod can acquire a lock when existing lock has expired TTL"""
     cronjob_id = str(uuid.uuid4())
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)

     # First pod acquires lock with a very short TTL to simulate expiration
     short_ttl = 1  # 1 second
@@ -269,8 +293,10 @@
     await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL

     # Second pod tries to acquire without explicit release
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    result = await second_lock_manager.acquire_lock()
+    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    result = await second_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )

     assert result == True, "Should acquire lock when existing lock has expired TTL"
@@ -286,7 +312,7 @@ async def test_release_expired_lock():
     cronjob_id = str(uuid.uuid4())

     # First pod acquires lock with a very short TTL
-    first_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
+    first_lock_manager = PodLockManager(redis_cache=global_redis_cache)
     short_ttl = 1  # 1 second
     lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
     await global_redis_cache.async_set_cache(
@@ -299,11 +325,15 @@
     await asyncio.sleep(short_ttl + 0.5)  # Wait slightly longer than the TTL

     # Second pod acquires the lock
-    second_lock_manager = PodLockManager(cronjob_id=cronjob_id, redis_cache=global_redis_cache)
-    await second_lock_manager.acquire_lock()
+    second_lock_manager = PodLockManager(redis_cache=global_redis_cache)
+    await second_lock_manager.acquire_lock(
+        cronjob_id=cronjob_id,
+    )

     # First pod attempts to release its lock
-    await first_lock_manager.release_lock()
+    await first_lock_manager.release_lock(
+        cronjob_id=cronjob_id,
+    )

     # Verify that second pod's lock is still active
     lock_record = await global_redis_cache.async_get_cache(lock_key)