From 9f42d15713a26540a051c69ffffeb46884db6f02 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 13 Apr 2024 18:31:35 -0700
Subject: [PATCH] feat(prometheus_services.py): track when redis calls fail

---
Review notes (commentary only; text between the "---" line above and the
first "diff --git" line is ignored by git am / git apply):

* Formatting: the line breaks and indentation below were restored from a
  whitespace-collapsed copy of this patch. The restored line breaks agree
  with every hunk header (17/13, 6/17, 7/7, 16/29, 12/25 lines) and with
  the diffstat (50 insertions, 17 deletions). Leading whitespace, the
  spacing inside the removed commented-out lines, and the two spaces before
  trailing "#" comments are inferred from Python structure, not recovered
  from the original -- TODO confirm against the repository before applying.

* What the change does: each entry of payload_to_prometheus_map becomes a
  [histogram, counter] list instead of a single histogram. The success
  hooks loop over the list and call observe_histogram on the Histogram
  with payload.duration; the failure hooks loop over it and call
  increment_counter on the Counter with amount=1.

* NOTE(review): create_counter names the metric "litellm_<label>_requests"
  while its help text reads "Total failed requests for <label> service",
  and this patch only increments it from the two failure hooks. The name
  therefore reads like a total but the value is a failure count; consider
  whether the name should say so.

* NOTE(review): create_counter returns self.get_metric(metric_name) when
  the metric is already registered, and the hooks select objects with
  isinstance(obj, self.Counter) / isinstance(obj, self.Histogram).
  get_metric is not part of this patch; if it does not return a
  Counter / Histogram instance the loops skip the object without any
  error -- verify.

* NOTE(review): the sync and async variants of each hook receive identical
  bodies in this patch; a shared private helper would remove the
  duplication.

* increment_counter: the "labels" annotation changes from list to str,
  which matches the call sites added here (labels=payload.service.value)
  and the existing observe_histogram call sites.

 litellm/integrations/prometheus_services.py | 67 +++++++++++++++------
 1 file changed, 50 insertions(+), 17 deletions(-)

diff --git a/litellm/integrations/prometheus_services.py b/litellm/integrations/prometheus_services.py
index a26271987..ecd75ad0b 100644
--- a/litellm/integrations/prometheus_services.py
+++ b/litellm/integrations/prometheus_services.py
@@ -40,17 +40,13 @@ class PrometheusServicesLogger:
 
             for service in self.services:
                 histogram = self.create_histogram(service)
-                self.payload_to_prometheus_map[service] = histogram
+                counter = self.create_counter(service)
+                self.payload_to_prometheus_map[service] = [histogram, counter]
 
             self.prometheus_to_amount_map: dict = (
                 {}
             )  # the field / value in ServiceLoggerPayload the object needs to be incremented by
 
-            # self.payload_to_prometheus_map["service"] = [self.litellm_service_latency]
-            # self.prometheus_to_amount_map[self.litellm_service_latency._name] = (
-            #     "duration"
-            # )
-
             ### MOCK TESTING ###
             self.mock_testing = mock_testing
             self.mock_testing_success_calls = 0
@@ -85,6 +81,17 @@ class PrometheusServicesLogger:
             labelnames=[label],
         )
 
+    def create_counter(self, label: str):
+        metric_name = "litellm_{}_requests".format(label)
+        is_registered = self.is_metric_registered(metric_name)
+        if is_registered:
+            return self.get_metric(metric_name)
+        return self.Counter(
+            metric_name,
+            "Total failed requests for {} service".format(label),
+            labelnames=[label],
+        )
+
     def observe_histogram(
         self,
         histogram,
@@ -98,7 +105,7 @@ class PrometheusServicesLogger:
     def increment_counter(
         self,
         counter,
-        labels: list,
+        labels: str,
         amount: float,
     ):
         assert isinstance(counter, self.Counter)
@@ -110,16 +117,29 @@ class PrometheusServicesLogger:
             self.mock_testing_success_calls += 1
 
         if payload.service.value in self.payload_to_prometheus_map:
-            self.observe_histogram(
-                histogram=self.payload_to_prometheus_map[payload.service.value],
-                labels=payload.service.value,
-                amount=payload.duration,
-            )
+            prom_objects = self.payload_to_prometheus_map[payload.service.value]
+            for obj in prom_objects:
+                if isinstance(obj, self.Histogram):
+                    self.observe_histogram(
+                        histogram=obj,
+                        labels=payload.service.value,
+                        amount=payload.duration,
+                    )
 
     def service_failure_hook(self, payload: ServiceLoggerPayload):
         if self.mock_testing:
             self.mock_testing_failure_calls += 1
 
+        if payload.service.value in self.payload_to_prometheus_map:
+            prom_objects = self.payload_to_prometheus_map[payload.service.value]
+            for obj in prom_objects:
+                if isinstance(obj, self.Counter):
+                    self.increment_counter(
+                        counter=obj,
+                        labels=payload.service.value,
+                        amount=1,  # LOG ERROR COUNT TO PROMETHEUS
+                    )
+
     async def async_service_success_hook(self, payload: ServiceLoggerPayload):
         """
         Log successful call to prometheus
@@ -128,12 +148,25 @@ class PrometheusServicesLogger:
             self.mock_testing_success_calls += 1
 
         if payload.service.value in self.payload_to_prometheus_map:
-            self.observe_histogram(
-                histogram=self.payload_to_prometheus_map[payload.service.value],
-                labels=payload.service.value,
-                amount=payload.duration,
-            )
+            prom_objects = self.payload_to_prometheus_map[payload.service.value]
+            for obj in prom_objects:
+                if isinstance(obj, self.Histogram):
+                    self.observe_histogram(
+                        histogram=obj,
+                        labels=payload.service.value,
+                        amount=payload.duration,
+                    )
 
     async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
         if self.mock_testing:
             self.mock_testing_failure_calls += 1
+
+        if payload.service.value in self.payload_to_prometheus_map:
+            prom_objects = self.payload_to_prometheus_map[payload.service.value]
+            for obj in prom_objects:
+                if isinstance(obj, self.Counter):
+                    self.increment_counter(
+                        counter=obj,
+                        labels=payload.service.value,
+                        amount=1,  # LOG ERROR COUNT TO PROMETHEUS
+                    )