Merge pull request #9719 from BerriAI/litellm_metrics_pod_lock_manager

[Reliability] Emit operational metrics for new DB Transaction architecture
This commit is contained in:
Ishaan Jaff 2025-04-04 21:12:06 -07:00 committed by GitHub
commit 8c3670e192
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 351 additions and 40 deletions

View file

@ -156,7 +156,7 @@ PROXY_LOGOUT_URL="https://www.google.com"
Set this in your .env (so the proxy can set the correct redirect url)
```shell
PROXY_BASE_URL=https://litellm-api.up.railway.app/
PROXY_BASE_URL=https://litellm-api.up.railway.app
```
#### Step 4. Test flow

View file

@ -124,6 +124,7 @@ class ServiceLogging(CustomLogger):
service=service,
duration=duration,
call_type=call_type,
event_metadata=event_metadata,
)
for callback in litellm.service_callback:
@ -229,6 +230,7 @@ class ServiceLogging(CustomLogger):
service=service,
duration=duration,
call_type=call_type,
event_metadata=event_metadata,
)
for callback in litellm.service_callback:

View file

@ -3,11 +3,16 @@
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union
from litellm._logging import print_verbose, verbose_logger
from litellm.types.integrations.prometheus import LATENCY_BUCKETS
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from litellm.types.services import (
DEFAULT_SERVICE_CONFIGS,
ServiceLoggerPayload,
ServiceMetrics,
ServiceTypes,
)
FAILED_REQUESTS_LABELS = ["error_class", "function_name"]
@ -23,7 +28,8 @@ class PrometheusServicesLogger:
):
try:
try:
from prometheus_client import REGISTRY, Counter, Histogram
from prometheus_client import REGISTRY, Counter, Gauge, Histogram
from prometheus_client.gc_collector import Collector
except ImportError:
raise Exception(
"Missing prometheus_client. Run `pip install prometheus-client`"
@ -31,36 +37,51 @@ class PrometheusServicesLogger:
self.Histogram = Histogram
self.Counter = Counter
self.Gauge = Gauge
self.REGISTRY = REGISTRY
verbose_logger.debug("in init prometheus services metrics")
self.services = [item.value for item in ServiceTypes]
self.payload_to_prometheus_map: Dict[
str, List[Union[Histogram, Counter, Gauge, Collector]]
] = {}
self.payload_to_prometheus_map = (
{}
) # store the prometheus histogram/counter we need to call for each field in payload
for service in ServiceTypes:
service_metrics: List[Union[Histogram, Counter, Gauge, Collector]] = []
for service in self.services:
histogram = self.create_histogram(service, type_of_request="latency")
counter_failed_request = self.create_counter(
service,
type_of_request="failed_requests",
additional_labels=FAILED_REQUESTS_LABELS,
)
counter_total_requests = self.create_counter(
service, type_of_request="total_requests"
)
self.payload_to_prometheus_map[service] = [
histogram,
counter_failed_request,
counter_total_requests,
]
metrics_to_initialize = self._get_service_metrics_initialize(service)
self.prometheus_to_amount_map: dict = (
{}
) # the field / value in ServiceLoggerPayload the object needs to be incremented by
# Initialize only the configured metrics for each service
if ServiceMetrics.HISTOGRAM in metrics_to_initialize:
histogram = self.create_histogram(
service.value, type_of_request="latency"
)
if histogram:
service_metrics.append(histogram)
if ServiceMetrics.COUNTER in metrics_to_initialize:
counter_failed_request = self.create_counter(
service.value,
type_of_request="failed_requests",
additional_labels=FAILED_REQUESTS_LABELS,
)
if counter_failed_request:
service_metrics.append(counter_failed_request)
counter_total_requests = self.create_counter(
service.value, type_of_request="total_requests"
)
if counter_total_requests:
service_metrics.append(counter_total_requests)
if ServiceMetrics.GAUGE in metrics_to_initialize:
gauge = self.create_gauge(service.value, type_of_request="size")
if gauge:
service_metrics.append(gauge)
if service_metrics:
self.payload_to_prometheus_map[service.value] = service_metrics
self.prometheus_to_amount_map: dict = {}
### MOCK TESTING ###
self.mock_testing = mock_testing
self.mock_testing_success_calls = 0
@ -70,6 +91,19 @@ class PrometheusServicesLogger:
print_verbose(f"Got exception on init prometheus client {str(e)}")
raise e
def _get_service_metrics_initialize(
    self, service: ServiceTypes
) -> List[ServiceMetrics]:
    """
    Return the metric types (counter/histogram/gauge) to create for `service`.

    Falls back to COUNTER + HISTOGRAM when the service has no entry in
    DEFAULT_SERVICE_CONFIGS or its entry declares no metrics.
    """
    DEFAULT_METRICS = [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
    # DEFAULT_SERVICE_CONFIGS is keyed by the enum *value* (a plain str).
    # Enum members hash by member name, not value, so a lookup with the
    # ServiceTypes member itself always misses - use `.value` for the key.
    # (getattr keeps this working if a caller passes a plain string.)
    service_key = getattr(service, "value", service)
    if service_key not in DEFAULT_SERVICE_CONFIGS:
        return DEFAULT_METRICS

    metrics = DEFAULT_SERVICE_CONFIGS.get(service_key, {}).get("metrics", [])
    if not metrics:
        verbose_logger.debug(f"No metrics found for service {service}")
        return DEFAULT_METRICS
    return metrics
def is_metric_registered(self, metric_name) -> bool:
for metric in self.REGISTRY.collect():
if metric_name == metric.name:
@ -94,6 +128,15 @@ class PrometheusServicesLogger:
buckets=LATENCY_BUCKETS,
)
def create_gauge(self, service: str, type_of_request: str):
    """Create (or fetch) the Gauge named ``litellm_<service>_<type_of_request>``.

    If a collector with that name is already registered, the existing one
    is returned instead of registering a duplicate.
    """
    metric_name = "litellm_{}_{}".format(service, type_of_request)
    if self.is_metric_registered(metric_name):
        # Registering the same name twice raises in prometheus_client,
        # so hand back the collector that is already in the registry.
        return self._get_metric(metric_name)
    description = "Gauge for {} service".format(service)
    return self.Gauge(metric_name, description, labelnames=[service])
def create_counter(
self,
service: str,
@ -120,6 +163,15 @@ class PrometheusServicesLogger:
histogram.labels(labels).observe(amount)
def update_gauge(
    self,
    gauge,
    labels: str,
    amount: float,
):
    """Set the labelled ``gauge`` to ``amount``."""
    # Guard against being handed a counter/histogram by mistake.
    assert isinstance(gauge, self.Gauge)
    labelled_gauge = gauge.labels(labels)
    labelled_gauge.set(amount)
def increment_counter(
self,
counter,
@ -190,6 +242,13 @@ class PrometheusServicesLogger:
labels=payload.service.value,
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
)
elif isinstance(obj, self.Gauge):
if payload.event_metadata:
self.update_gauge(
gauge=obj,
labels=payload.event_metadata.get("gauge_labels") or "",
amount=payload.event_metadata.get("gauge_value") or 0,
)
async def async_service_failure_hook(
self,

View file

@ -2,8 +2,14 @@
Base class for in memory buffer for database transactions
"""
import asyncio
from typing import Optional
from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging
# Shared ServiceLogging instance used for tracking operational metrics for
# the in-memory buffers, the redis buffer, and the pod lock manager.
service_logger_obj = ServiceLogging()
from litellm.constants import MAX_IN_MEMORY_QUEUE_FLUSH_COUNT, MAX_SIZE_IN_MEMORY_QUEUE
@ -18,6 +24,9 @@ class BaseUpdateQueue:
"""Enqueue an update."""
verbose_proxy_logger.debug("Adding update to queue: %s", update)
await self.update_queue.put(update)
await self._emit_new_item_added_to_queue_event(
queue_size=self.update_queue.qsize()
)
async def flush_all_updates_from_in_memory_queue(self):
"""Get all updates from the queue."""
@ -31,3 +40,10 @@ class BaseUpdateQueue:
break
updates.append(await self.update_queue.get())
return updates
async def _emit_new_item_added_to_queue_event(
    self,
    queue_size: Optional[int] = None,
):
    """Hook called after an item is enqueued.

    No-op in the base class; subclasses override it to emit a queue-size
    metric. ``queue_size`` is the size of the queue after the new item
    was added.
    """
    return None

View file

@ -1,10 +1,14 @@
import asyncio
from copy import deepcopy
from typing import Dict, List
from typing import Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import DailyUserSpendTransaction
from litellm.proxy.db.db_transaction_queue.base_update_queue import BaseUpdateQueue
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
BaseUpdateQueue,
service_logger_obj,
)
from litellm.types.services import ServiceTypes
class DailySpendUpdateQueue(BaseUpdateQueue):
@ -117,3 +121,19 @@ class DailySpendUpdateQueue(BaseUpdateQueue):
else:
aggregated_daily_spend_update_transactions[_key] = deepcopy(payload)
return aggregated_daily_spend_update_transactions
async def _emit_new_item_added_to_queue_event(
    self,
    queue_size: Optional[int] = None,
):
    """Emit the in-memory daily-spend queue size as a service metric.

    Fire-and-forget: the service-logger hook is scheduled as a background
    task so enqueueing is never blocked on metrics emission.
    """
    task = asyncio.create_task(
        service_logger_obj.async_service_success_hook(
            service=ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
            duration=0,
            call_type="_emit_new_item_added_to_queue_event",
            event_metadata={
                "gauge_labels": ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
                "gauge_value": queue_size,
            },
        )
    )
    # The event loop only keeps a weak reference to tasks; hold a strong
    # reference until completion so the emit task cannot be
    # garbage-collected before it runs.
    pending = getattr(self, "_metric_emit_tasks", None)
    if pending is None:
        pending = set()
        self._metric_emit_tasks = pending
    pending.add(task)
    task.add_done_callback(pending.discard)

View file

@ -1,9 +1,12 @@
import asyncio
import uuid
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.caching.redis_cache import RedisCache
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
from litellm.types.services import ServiceTypes
if TYPE_CHECKING:
ProxyLogging = Any
@ -57,6 +60,7 @@ class PodLockManager:
self.pod_id,
self.cronjob_id,
)
return True
else:
# Check if the current pod already holds the lock
@ -70,6 +74,7 @@ class PodLockManager:
self.pod_id,
self.cronjob_id,
)
self._emit_acquired_lock_event(self.cronjob_id, self.pod_id)
return True
return False
except Exception as e:
@ -104,6 +109,7 @@ class PodLockManager:
self.pod_id,
self.cronjob_id,
)
self._emit_released_lock_event(self.cronjob_id, self.pod_id)
else:
verbose_proxy_logger.debug(
"Pod %s failed to release Redis lock for cronjob_id=%s",
@ -127,3 +133,31 @@ class PodLockManager:
verbose_proxy_logger.error(
f"Error releasing Redis lock for {self.cronjob_id}: {e}"
)
@staticmethod
def _emit_acquired_lock_event(cronjob_id: str, pod_id: str):
    """Emit gauge_value=1 marking that this pod acquired the cronjob lock.

    Fire-and-forget: schedules the service-logger hook as a background
    task, so this must be called from a running event loop.
    NOTE(review): asyncio only keeps a weak reference to the task; with no
    strong reference held it could be garbage-collected before running -
    confirm this is acceptable for best-effort metrics.
    """
    asyncio.create_task(
        service_logger_obj.async_service_success_hook(
            service=ServiceTypes.POD_LOCK_MANAGER,
            # duration is set to the lock TTL - presumably the expected
            # hold time of the lock; TODO confirm intent.
            duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
            call_type="_emit_acquired_lock_event",
            event_metadata={
                # one gauge series per cronjob/pod pair
                "gauge_labels": f"{cronjob_id}:{pod_id}",
                "gauge_value": 1,  # 1 => lock held by this pod
            },
        )
    )
@staticmethod
def _emit_released_lock_event(cronjob_id: str, pod_id: str):
    """Emit gauge_value=0 marking that this pod released the cronjob lock.

    Fire-and-forget: schedules the service-logger hook as a background
    task, so this must be called from a running event loop.
    NOTE(review): asyncio only keeps a weak reference to the task; with no
    strong reference held it could be garbage-collected before running -
    confirm this is acceptable for best-effort metrics.
    """
    asyncio.create_task(
        service_logger_obj.async_service_success_hook(
            service=ServiceTypes.POD_LOCK_MANAGER,
            # duration is set to the lock TTL - presumably the expected
            # hold time of the lock; TODO confirm intent.
            duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
            call_type="_emit_released_lock_event",
            event_metadata={
                # one gauge series per cronjob/pod pair
                "gauge_labels": f"{cronjob_id}:{pod_id}",
                "gauge_value": 0,  # 0 => lock no longer held by this pod
            },
        )
    )

View file

@ -4,6 +4,7 @@ Handles buffering database `UPDATE` transactions in Redis before committing them
This is to prevent deadlocks and improve reliability
"""
import asyncio
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@ -16,11 +17,13 @@ from litellm.constants import (
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import DailyUserSpendTransaction, DBSpendUpdateTransactions
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
DailySpendUpdateQueue,
)
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue
from litellm.secret_managers.main import str_to_bool
from litellm.types.services import ServiceTypes
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
@ -136,18 +139,27 @@ class RedisUpdateBuffer:
return
list_of_transactions = [safe_dumps(db_spend_update_transactions)]
await self.redis_cache.async_rpush(
current_redis_buffer_size = await self.redis_cache.async_rpush(
key=REDIS_UPDATE_BUFFER_KEY,
values=list_of_transactions,
)
await self._emit_new_item_added_to_redis_buffer_event(
queue_size=current_redis_buffer_size,
service=ServiceTypes.REDIS_SPEND_UPDATE_QUEUE,
)
list_of_daily_spend_update_transactions = [
safe_dumps(daily_spend_update_transactions)
]
await self.redis_cache.async_rpush(
current_redis_buffer_size = await self.redis_cache.async_rpush(
key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
values=list_of_daily_spend_update_transactions,
)
await self._emit_new_item_added_to_redis_buffer_event(
queue_size=current_redis_buffer_size,
service=ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE,
)
@staticmethod
def _number_of_transactions_to_store_in_redis(
@ -300,3 +312,20 @@ class RedisUpdateBuffer:
)
return combined_transaction
async def _emit_new_item_added_to_redis_buffer_event(
    self,
    service: ServiceTypes,
    queue_size: int,
):
    """Emit the current redis buffer size for `service` as a service metric.

    Fire-and-forget: the service-logger hook is scheduled as a background
    task so the redis write path is never blocked on metrics emission.
    """
    task = asyncio.create_task(
        service_logger_obj.async_service_success_hook(
            service=service,
            duration=0,
            # NOTE: call_type intentionally matches the in-memory queue
            # emitters so dashboards see one call-type for all queues.
            call_type="_emit_new_item_added_to_queue_event",
            event_metadata={
                "gauge_labels": service,
                "gauge_value": queue_size,
            },
        )
    )
    # The event loop only keeps a weak reference to tasks; hold a strong
    # reference until completion so the emit task cannot be
    # garbage-collected before it runs.
    pending = getattr(self, "_metric_emit_tasks", None)
    if pending is None:
        pending = set()
        self._metric_emit_tasks = pending
    pending.add(task)
    task.add_done_callback(pending.discard)

View file

@ -1,5 +1,5 @@
import asyncio
from typing import Dict, List
from typing import Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
@ -7,7 +7,11 @@ from litellm.proxy._types import (
Litellm_EntityType,
SpendUpdateQueueItem,
)
from litellm.proxy.db.db_transaction_queue.base_update_queue import BaseUpdateQueue
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
BaseUpdateQueue,
service_logger_obj,
)
from litellm.types.services import ServiceTypes
class SpendUpdateQueue(BaseUpdateQueue):
@ -203,3 +207,19 @@ class SpendUpdateQueue(BaseUpdateQueue):
transactions_dict[entity_id] += response_cost or 0
return db_spend_update_transactions
async def _emit_new_item_added_to_queue_event(
    self,
    queue_size: Optional[int] = None,
):
    """Emit the in-memory spend queue size as a service metric.

    Fire-and-forget: the service-logger hook is scheduled as a background
    task so enqueueing is never blocked on metrics emission.
    """
    task = asyncio.create_task(
        service_logger_obj.async_service_success_hook(
            service=ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
            duration=0,
            call_type="_emit_new_item_added_to_queue_event",
            event_metadata={
                "gauge_labels": ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                "gauge_value": queue_size,
            },
        )
    )
    # The event loop only keeps a weak reference to tasks; hold a strong
    # reference until completion so the emit task cannot be
    # garbage-collected before it runs.
    pending = getattr(self, "_metric_emit_tasks", None)
    if pending is None:
        pending = set()
        self._metric_emit_tasks = pending
    pending.add(task)
    task.add_done_callback(pending.discard)

View file

@ -5,11 +5,6 @@ model_list:
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
use_redis_transaction_buffer: true
litellm_settings:
cache: True
cache_params:
type: redis
supported_call_types: []
callbacks: ["prometheus"]
service_callback: ["prometheus_system"]

View file

@ -1,8 +1,15 @@
import enum
import uuid
from typing import Optional
from typing import List, Optional
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
class ServiceMetrics(enum.Enum):
    """Metric types a service can be configured to emit (see DEFAULT_SERVICE_CONFIGS)."""

    COUNTER = "counter"  # monotonically increasing counts (total / failed requests)
    HISTOGRAM = "histogram"  # distributions, e.g. request latency
    GAUGE = "gauge"  # point-in-time values, e.g. queue size or lock state
class ServiceTypes(str, enum.Enum):
@ -18,6 +25,84 @@ class ServiceTypes(str, enum.Enum):
ROUTER = "router"
AUTH = "auth"
PROXY_PRE_CALL = "proxy_pre_call"
POD_LOCK_MANAGER = "pod_lock_manager"
"""
Operational metrics for DB Transaction Queues
"""
# daily spend update queue - actual transaction events
IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE = "in_memory_daily_spend_update_queue"
REDIS_DAILY_SPEND_UPDATE_QUEUE = "redis_daily_spend_update_queue"
# spend update queue - current spend of key, user, team
IN_MEMORY_SPEND_UPDATE_QUEUE = "in_memory_spend_update_queue"
REDIS_SPEND_UPDATE_QUEUE = "redis_spend_update_queue"
class ServiceConfig(TypedDict):
    """
    Configuration for services and their metrics
    """

    # Which metric types (counter/histogram/gauge) to create for the service.
    metrics: List[ServiceMetrics]  # What metrics this service should support
"""
Metric types to use for each service
- REDIS only needs Counter, Histogram
- Pod Lock Manager only needs a gauge metric
"""
# Keys are the ServiceTypes *values* (plain strings), not the enum members.
# Services without an entry fall back to COUNTER + HISTOGRAM.
DEFAULT_SERVICE_CONFIGS = {
    # Request-style services: count requests/failures and track latency.
    ServiceTypes.REDIS.value: {
        "metrics": [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
    },
    ServiceTypes.DB.value: {
        "metrics": [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
    },
    ServiceTypes.BATCH_WRITE_TO_DB.value: {
        "metrics": [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
    },
    ServiceTypes.RESET_BUDGET_JOB.value: {
        "metrics": [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
    },
    ServiceTypes.LITELLM.value: {
        "metrics": [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
    },
    ServiceTypes.ROUTER.value: {
        "metrics": [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
    },
    ServiceTypes.AUTH.value: {
        "metrics": [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
    },
    ServiceTypes.PROXY_PRE_CALL.value: {
        "metrics": [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
    },
    # Operational metrics for DB Transaction Queues
    # Queue/lock services report a current state (size, held/free), so a
    # gauge is the only metric they need.
    ServiceTypes.POD_LOCK_MANAGER.value: {"metrics": [ServiceMetrics.GAUGE]},
    ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE.value: {
        "metrics": [ServiceMetrics.GAUGE]
    },
    ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE.value: {
        "metrics": [ServiceMetrics.GAUGE]
    },
    ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE.value: {
        "metrics": [ServiceMetrics.GAUGE]
    },
    ServiceTypes.REDIS_SPEND_UPDATE_QUEUE.value: {"metrics": [ServiceMetrics.GAUGE]},
}
class ServiceEventMetadata(TypedDict, total=False):
    """
    The metadata logged during service success/failure

    Add any extra fields you expect to access in the service_success_hook/service_failure_hook

    total=False: every key is optional, so emitters only set what they use.
    """

    # Dynamically control gauge labels and values
    gauge_labels: Optional[str]  # label value applied to the gauge series
    gauge_value: Optional[float]  # value the gauge is set to
class ServiceLoggerPayload(BaseModel):
@ -30,6 +115,9 @@ class ServiceLoggerPayload(BaseModel):
service: ServiceTypes = Field(description="who is this for? - postgres/redis")
duration: float = Field(description="How long did the request take?")
call_type: str = Field(description="The call of the service, being made")
event_metadata: Optional[dict] = Field(
description="The metadata logged during service success/failure"
)
def to_json(self, **kwargs):
try:

View file

@ -0,0 +1,48 @@
import json
import os
import sys
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from litellm.integrations.prometheus_services import (
PrometheusServicesLogger,
ServiceMetrics,
ServiceTypes,
)
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
def test_create_gauge_new():
    """A fresh gauge is created and registered under the expected name."""
    logger = PrometheusServicesLogger()

    created = logger.create_gauge(service="test_service", type_of_request="size")

    assert created is not None
    # The registered collector is the exact object create_gauge returned.
    assert logger._get_metric("litellm_test_service_size") is created
def test_update_gauge():
    """Test updating a gauge's value"""
    pl = PrometheusServicesLogger()

    # Create a gauge to test with
    gauge = pl.create_gauge(service="test_service", type_of_request="size")

    # Mock the labels method to verify it's called correctly.
    # gauge.labels(...).set(...) is synchronous, so use the MagicMock that
    # patch.object auto-creates as the return value - an AsyncMock here
    # would make .set() return an un-awaited coroutine and trigger
    # "coroutine was never awaited" RuntimeWarnings.
    with patch.object(gauge, "labels") as mock_labels:
        mock_gauge = mock_labels.return_value

        # Call update_gauge
        pl.update_gauge(gauge=gauge, labels="test_label", amount=42.5)

        # Verify correct methods were called
        mock_labels.assert_called_once_with("test_label")
        mock_gauge.set.assert_called_once_with(42.5)