diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index 8f835bea83..7a60359d54 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -124,6 +124,7 @@ class ServiceLogging(CustomLogger): service=service, duration=duration, call_type=call_type, + event_metadata=event_metadata, ) for callback in litellm.service_callback: @@ -229,6 +230,7 @@ class ServiceLogging(CustomLogger): service=service, duration=duration, call_type=call_type, + event_metadata=event_metadata, ) for callback in litellm.service_callback: diff --git a/litellm/integrations/prometheus_services.py b/litellm/integrations/prometheus_services.py index 4bf293fb01..d14cbd7469 100644 --- a/litellm/integrations/prometheus_services.py +++ b/litellm/integrations/prometheus_services.py @@ -7,7 +7,12 @@ from typing import List, Optional, Union from litellm._logging import print_verbose, verbose_logger from litellm.types.integrations.prometheus import LATENCY_BUCKETS -from litellm.types.services import ServiceLoggerPayload, ServiceTypes +from litellm.types.services import ( + DEFAULT_SERVICE_CONFIGS, + ServiceLoggerPayload, + ServiceMetrics, + ServiceTypes, +) FAILED_REQUESTS_LABELS = ["error_class", "function_name"] @@ -23,7 +28,8 @@ class PrometheusServicesLogger: ): try: try: - from prometheus_client import REGISTRY, Counter, Histogram + from prometheus_client import REGISTRY, Counter, Gauge, Histogram + from prometheus_client.gc_collector import Collector except ImportError: raise Exception( "Missing prometheus_client. Run `pip install prometheus-client`" @@ -31,36 +37,52 @@ class PrometheusServicesLogger: self.Histogram = Histogram self.Counter = Counter + self.Gauge = Gauge self.REGISTRY = REGISTRY verbose_logger.debug("in init prometheus services metrics") - self.services = [item.value for item in ServiceTypes] - - self.payload_to_prometheus_map = ( - {} - ) # store the prometheus histogram/counter we need to call for each field in payload + self.services = [item for item in ServiceTypes] + self.payload_to_prometheus_map = {} for service in self.services: - histogram = self.create_histogram(service, type_of_request="latency") - counter_failed_request = self.create_counter( - service, - type_of_request="failed_requests", - additional_labels=FAILED_REQUESTS_LABELS, - ) - counter_total_requests = self.create_counter( - service, type_of_request="total_requests" - ) - self.payload_to_prometheus_map[service] = [ - histogram, - counter_failed_request, - counter_total_requests, - ] + service_metrics: List[Union[Histogram, Counter, Gauge, Collector]] = [] - self.prometheus_to_amount_map: dict = ( - {} - ) # the field / value in ServiceLoggerPayload the object needs to be incremented by + metrics_to_initialize = self._get_service_metrics_initialize(service) + # Initialize only the configured metrics for each service + if ServiceMetrics.HISTOGRAM in metrics_to_initialize: + histogram = self.create_histogram( + service, type_of_request="latency" + ) + if histogram: + service_metrics.append(histogram) + + if ServiceMetrics.COUNTER in metrics_to_initialize: + counter_failed_request = self.create_counter( + service, + type_of_request="failed_requests", + additional_labels=FAILED_REQUESTS_LABELS, + ) + if counter_failed_request: + service_metrics.append(counter_failed_request) + counter_total_requests = self.create_counter( + service, type_of_request="total_requests" + ) + if counter_total_requests: + service_metrics.append(counter_total_requests) + + if ServiceMetrics.GAUGE in metrics_to_initialize: + gauge = self.create_gauge( + service, type_of_request="pod_lock_manager" + ) + if gauge: + service_metrics.append(gauge) + + if service_metrics: + self.payload_to_prometheus_map[service] = service_metrics + + self.prometheus_to_amount_map: dict = {} ### MOCK TESTING ### self.mock_testing = mock_testing self.mock_testing_success_calls = 0 @@ -70,6 +92,17 @@ class PrometheusServicesLogger: print_verbose(f"Got exception on init prometheus client {str(e)}") raise e + def _get_service_metrics_initialize( + self, service: ServiceTypes + ) -> List[ServiceMetrics]: + if service not in DEFAULT_SERVICE_CONFIGS: + raise ValueError(f"Service {service} not found in DEFAULT_SERVICE_CONFIGS") + + metrics = DEFAULT_SERVICE_CONFIGS.get(service, {}).get("metrics", []) + if not metrics: + raise ValueError(f"No metrics found for service {service}") + return metrics + def is_metric_registered(self, metric_name) -> bool: for metric in self.REGISTRY.collect(): if metric_name == metric.name: @@ -94,6 +127,15 @@ class PrometheusServicesLogger: buckets=LATENCY_BUCKETS, ) + def create_gauge(self, service: str, type_of_request: str): + metric_name = "litellm_{}_{}".format(service, type_of_request) + is_registered = self.is_metric_registered(metric_name) + if is_registered: + return self._get_metric(metric_name) + return self.Gauge( + metric_name, "Gauge for {} service".format(service), labelnames=[service] + ) + def create_counter( self, service: str, @@ -120,6 +162,15 @@ class PrometheusServicesLogger: histogram.labels(labels).observe(amount) + def update_gauge( + self, + gauge, + labels: str, + amount: float, + ): + assert isinstance(gauge, self.Gauge) + gauge.labels(labels).set(amount) + def increment_counter( self, counter, @@ -190,6 +241,13 @@ class PrometheusServicesLogger: labels=payload.service.value, amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS ) + elif isinstance(obj, self.Gauge): + if payload.event_metadata: + self.update_gauge( + gauge=obj, + labels=payload.event_metadata.get("gauge_labels") or "", + amount=payload.event_metadata.get("gauge_value") or 0, + ) async def async_service_failure_hook( self, diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index fe8d73d26a..2ee830bca4 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -12,4 +12,6 @@ litellm_settings: cache: True cache_params: type: redis - supported_call_types: [] \ No newline at end of file + supported_call_types: [] + callbacks: ["prometheus"] + service_callback: ["prometheus_system"] \ No newline at end of file