diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py
index a92984706..8a4f409b6 100644
--- a/litellm/integrations/prometheus.py
+++ b/litellm/integrations/prometheus.py
@@ -6,7 +6,7 @@ import subprocess
 import sys
 import traceback
 import uuid
-from datetime import datetime, timedelta
+from datetime import date, datetime, timedelta
 from typing import Optional, TypedDict, Union
 
 import dotenv
@@ -334,13 +334,8 @@ class PrometheusLogger(CustomLogger):
             print_verbose(f"Got exception on init prometheus client {str(e)}")
             raise e
 
-    async def async_log_success_event(  # noqa: PLR0915
-        self, kwargs, response_obj, start_time, end_time
-    ):
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         # Define prometheus client
-        from litellm.proxy.common_utils.callback_utils import (
-            get_model_group_from_litellm_kwargs,
-        )
         from litellm.types.utils import StandardLoggingPayload
 
         verbose_logger.debug(
@@ -358,7 +353,6 @@ class PrometheusLogger(CustomLogger):
         _metadata = litellm_params.get("metadata", {})
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         end_user_id = proxy_server_request.get("body", {}).get("user", None)
-        model_parameters: dict = standard_logging_payload["model_parameters"]
         user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
         user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
         user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
@@ -369,25 +363,6 @@ class PrometheusLogger(CustomLogger):
         output_tokens = standard_logging_payload["completion_tokens"]
         tokens_used = standard_logging_payload["total_tokens"]
         response_cost = standard_logging_payload["response_cost"]
-        _team_spend = litellm_params.get("metadata", {}).get(
-            "user_api_key_team_spend", None
-        )
-        _team_max_budget = litellm_params.get("metadata", {}).get(
-            "user_api_key_team_max_budget", None
-        )
-        _remaining_team_budget = safe_get_remaining_budget(
-            max_budget=_team_max_budget, spend=_team_spend
-        )
-
-        _api_key_spend = litellm_params.get("metadata", {}).get(
-            "user_api_key_spend", None
-        )
-        _api_key_max_budget = litellm_params.get("metadata", {}).get(
-            "user_api_key_max_budget", None
-        )
-        _remaining_api_key_budget = safe_get_remaining_budget(
-            max_budget=_api_key_max_budget, spend=_api_key_spend
-        )
 
         print_verbose(
             f"inside track_prometheus_metrics, model {model}, response_cost {response_cost}, tokens_used {tokens_used}, end_user_id {end_user_id}, user_api_key {user_api_key}"
@@ -402,24 +377,76 @@ class PrometheusLogger(CustomLogger):
             user_api_key = hash_token(user_api_key)
 
-        self.litellm_requests_metric.labels(
-            end_user_id,
-            user_api_key,
-            user_api_key_alias,
-            model,
-            user_api_team,
-            user_api_team_alias,
-            user_id,
-        ).inc()
-        self.litellm_spend_metric.labels(
-            end_user_id,
-            user_api_key,
-            user_api_key_alias,
-            model,
-            user_api_team,
-            user_api_team_alias,
-            user_id,
-        ).inc(response_cost)
+        # increment total LLM requests and spend metric
+        self._increment_top_level_request_and_spend_metrics(
+            end_user_id=end_user_id,
+            user_api_key=user_api_key,
+            user_api_key_alias=user_api_key_alias,
+            model=model,
+            user_api_team=user_api_team,
+            user_api_team_alias=user_api_team_alias,
+            user_id=user_id,
+            response_cost=response_cost,
+        )
+
+        # input, output, total token metrics
+        self._increment_token_metrics(
+            standard_logging_payload=standard_logging_payload,
+            end_user_id=end_user_id,
+            user_api_key=user_api_key,
+            user_api_key_alias=user_api_key_alias,
+            model=model,
+            user_api_team=user_api_team,
+            user_api_team_alias=user_api_team_alias,
+            user_id=user_id,
+        )
+
+        # remaining budget metrics
+        self._increment_remaining_budget_metrics(
+            user_api_team=user_api_team,
+            user_api_team_alias=user_api_team_alias,
+            user_api_key=user_api_key,
+            user_api_key_alias=user_api_key_alias,
+            litellm_params=litellm_params,
+        )
+
+        # set proxy virtual key rpm/tpm metrics
+        self._set_virtual_key_rate_limit_metrics(
+            user_api_key=user_api_key,
+            user_api_key_alias=user_api_key_alias,
+            kwargs=kwargs,
+            metadata=_metadata,
+        )
+
+        # set latency metrics
+        self._set_latency_metrics(
+            kwargs=kwargs,
+            model=model,
+            user_api_key=user_api_key,
+            user_api_key_alias=user_api_key_alias,
+            user_api_team=user_api_team,
+            user_api_team_alias=user_api_team_alias,
+            standard_logging_payload=standard_logging_payload,
+        )
+
+        # set x-ratelimit headers
+        self.set_llm_deployment_success_metrics(
+            kwargs, start_time, end_time, output_tokens
+        )
+        pass
+
+    def _increment_token_metrics(
+        self,
+        standard_logging_payload: StandardLoggingPayload,
+        end_user_id: Optional[str],
+        user_api_key: Optional[str],
+        user_api_key_alias: Optional[str],
+        model: Optional[str],
+        user_api_team: Optional[str],
+        user_api_team_alias: Optional[str],
+        user_id: Optional[str],
+    ):
+        # token metrics
         self.litellm_tokens_metric.labels(
             end_user_id,
             user_api_key,
@@ -450,6 +477,34 @@ class PrometheusLogger(CustomLogger):
             user_id,
         ).inc(standard_logging_payload["completion_tokens"])
 
+    def _increment_remaining_budget_metrics(
+        self,
+        user_api_team: Optional[str],
+        user_api_team_alias: Optional[str],
+        user_api_key: Optional[str],
+        user_api_key_alias: Optional[str],
+        litellm_params: dict,
+    ):
+        _team_spend = litellm_params.get("metadata", {}).get(
+            "user_api_key_team_spend", None
+        )
+        _team_max_budget = litellm_params.get("metadata", {}).get(
+            "user_api_key_team_max_budget", None
+        )
+        _remaining_team_budget = self._safe_get_remaining_budget(
+            max_budget=_team_max_budget, spend=_team_spend
+        )
+
+        _api_key_spend = litellm_params.get("metadata", {}).get(
+            "user_api_key_spend", None
+        )
+        _api_key_max_budget = litellm_params.get("metadata", {}).get(
+            "user_api_key_max_budget", None
+        )
+        _remaining_api_key_budget = self._safe_get_remaining_budget(
+            max_budget=_api_key_max_budget, spend=_api_key_spend
+        )
+
         # Remaining Budget Metrics
         self.litellm_remaining_team_budget_metric.labels(
             user_api_team, user_api_team_alias
         ).set(_remaining_team_budget)
@@ -458,6 +513,47 @@ class PrometheusLogger(CustomLogger):
             user_api_key, user_api_key_alias
         ).set(_remaining_api_key_budget)
 
+    def _increment_top_level_request_and_spend_metrics(
+        self,
+        end_user_id: Optional[str],
+        user_api_key: Optional[str],
+        user_api_key_alias: Optional[str],
+        model: Optional[str],
+        user_api_team: Optional[str],
+        user_api_team_alias: Optional[str],
+        user_id: Optional[str],
+        response_cost: float,
+    ):
+        self.litellm_requests_metric.labels(
+            end_user_id,
+            user_api_key,
+            user_api_key_alias,
+            model,
+            user_api_team,
+            user_api_team_alias,
+            user_id,
+        ).inc()
+        self.litellm_spend_metric.labels(
+            end_user_id,
+            user_api_key,
+            user_api_key_alias,
+            model,
+            user_api_team,
+            user_api_team_alias,
+            user_id,
+        ).inc(response_cost)
+
+    def _set_virtual_key_rate_limit_metrics(
+        self,
+        user_api_key: Optional[str],
+        user_api_key_alias: Optional[str],
+        kwargs: dict,
+        metadata: dict,
+    ):
+        from litellm.proxy.common_utils.callback_utils import (
+            get_model_group_from_litellm_kwargs,
+        )
+
         # Set remaining rpm/tpm for API Key + model
         # see parallel_request_limiter.py - variables are set there
         model_group = get_model_group_from_litellm_kwargs(kwargs)
@@ -466,10 +562,8 @@ class PrometheusLogger(CustomLogger):
         )
         remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}"
 
-        remaining_requests = _metadata.get(
-            remaining_requests_variable_name, sys.maxsize
-        )
-        remaining_tokens = _metadata.get(remaining_tokens_variable_name, sys.maxsize)
+        remaining_requests = metadata.get(remaining_requests_variable_name, sys.maxsize)
+        remaining_tokens = metadata.get(remaining_tokens_variable_name, sys.maxsize)
 
         self.litellm_remaining_api_key_requests_for_model.labels(
             user_api_key, user_api_key_alias, model_group
@@ -479,9 +573,20 @@ class PrometheusLogger(CustomLogger):
             user_api_key, user_api_key_alias, model_group
         ).set(remaining_tokens)
 
+    def _set_latency_metrics(
+        self,
+        kwargs: dict,
+        model: Optional[str],
+        user_api_key: Optional[str],
+        user_api_key_alias: Optional[str],
+        user_api_team: Optional[str],
+        user_api_team_alias: Optional[str],
+        standard_logging_payload: StandardLoggingPayload,
+    ):
         # latency metrics
-        total_time: timedelta = kwargs.get("end_time") - kwargs.get("start_time")
-        total_time_seconds = total_time.total_seconds()
+        model_parameters: dict = standard_logging_payload["model_parameters"]
+        end_time: datetime = kwargs.get("end_time") or datetime.now()
+        start_time: Optional[datetime] = kwargs.get("start_time")
         api_call_start_time = kwargs.get("api_call_start_time", None)
         completion_start_time = kwargs.get("completion_start_time", None)
@@ -509,9 +614,7 @@ class PrometheusLogger(CustomLogger):
         if api_call_start_time is not None and isinstance(
             api_call_start_time, datetime
         ):
-            api_call_total_time: timedelta = (
-                kwargs.get("end_time") - api_call_start_time
-            )
+            api_call_total_time: timedelta = end_time - api_call_start_time
             api_call_total_time_seconds = api_call_total_time.total_seconds()
             self.litellm_llm_api_latency_metric.labels(
                 model,
@@ -521,20 +624,17 @@ class PrometheusLogger(CustomLogger):
                 user_api_key,
                 user_api_key_alias,
                 user_api_team,
                 user_api_team_alias,
             ).observe(api_call_total_time_seconds)
 
-        # log metrics
-        self.litellm_request_total_latency_metric.labels(
-            model,
-            user_api_key,
-            user_api_key_alias,
-            user_api_team,
-            user_api_team_alias,
-        ).observe(total_time_seconds)
-
-        # set x-ratelimit headers
-        self.set_llm_deployment_success_metrics(
-            kwargs, start_time, end_time, output_tokens
-        )
-        pass
+        # total request latency
+        if start_time is not None and isinstance(start_time, datetime):
+            total_time: timedelta = end_time - start_time
+            total_time_seconds = total_time.total_seconds()
+            self.litellm_request_total_latency_metric.labels(
+                model,
+                user_api_key,
+                user_api_key_alias,
+                user_api_team,
+                user_api_team_alias,
+            ).observe(total_time_seconds)
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         from litellm.types.utils import StandardLoggingPayload
@@ -1007,14 +1107,13 @@ class PrometheusLogger(CustomLogger):
             litellm_model_name, model_id, api_base, api_provider, exception_status
         ).inc()
 
+    def _safe_get_remaining_budget(
+        self, max_budget: Optional[float], spend: Optional[float]
+    ) -> float:
+        if max_budget is None:
+            return float("inf")
 
-def safe_get_remaining_budget(
-    max_budget: Optional[float], spend: Optional[float]
-) -> float:
-    if max_budget is None:
-        return float("inf")
+        if spend is None:
+            return max_budget
 
-    if spend is None:
-        return max_budget
-
-    return max_budget - spend
+        return max_budget - spend
diff --git a/tests/local_testing/test_prometheus.py b/tests/local_testing/test_prometheus.py
index 2f0e4a19e..164d94553 100644
--- a/tests/local_testing/test_prometheus.py
+++ b/tests/local_testing/test_prometheus.py
@@ -16,6 +16,14 @@ from litellm import completion
 from litellm._logging import verbose_logger
 from litellm.integrations.prometheus import PrometheusLogger
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.types.utils import (
+    StandardLoggingPayload,
+    StandardLoggingMetadata,
+    StandardLoggingHiddenParams,
+    StandardLoggingModelInformation,
+)
+from unittest.mock import MagicMock, patch
+from datetime import datetime, timedelta
 
 verbose_logger.setLevel(logging.DEBUG)
 
diff --git a/tests/logging_callback_tests/test_prometheus_unit_tests.py b/tests/logging_callback_tests/test_prometheus_unit_tests.py
new file mode 100644
index 000000000..035569273
--- /dev/null
+++ b/tests/logging_callback_tests/test_prometheus_unit_tests.py
@@ -0,0 +1,344 @@
+import io
+import os
+import sys
+
+sys.path.insert(0, os.path.abspath("../.."))
+
+import asyncio
+import logging
+import uuid
+
+import pytest
+from prometheus_client import REGISTRY, CollectorRegistry
+
+import litellm
+from litellm import completion
+from litellm._logging import verbose_logger
+from litellm.integrations.prometheus import PrometheusLogger
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.types.utils import (
+    StandardLoggingPayload,
+    StandardLoggingMetadata,
+    StandardLoggingHiddenParams,
+    StandardLoggingModelInformation,
+)
+import pytest
+from unittest.mock import MagicMock, patch
+from datetime import datetime, timedelta
+from litellm.integrations.prometheus import PrometheusLogger
+
+verbose_logger.setLevel(logging.DEBUG)
+
+litellm.set_verbose = True
+import time
+
+
+@pytest.fixture
+def prometheus_logger():
+    collectors = list(REGISTRY._collector_to_names.keys())
+    for collector in collectors:
+        REGISTRY.unregister(collector)
+    return PrometheusLogger()
+
+
+def create_standard_logging_payload() -> StandardLoggingPayload:
+    return StandardLoggingPayload(
+        id="test_id",
+        call_type="completion",
+        response_cost=0.1,
+        response_cost_failure_debug_info=None,
+        status="success",
+        total_tokens=30,
+        prompt_tokens=20,
+        completion_tokens=10,
+        startTime=1234567890.0,
+        endTime=1234567891.0,
+        completionStartTime=1234567890.5,
+        model_map_information=StandardLoggingModelInformation(
+            model_map_key="gpt-3.5-turbo", model_map_value=None
+        ),
+        model="gpt-3.5-turbo",
+        model_id="model-123",
+        model_group="openai-gpt",
+        api_base="https://api.openai.com",
+        metadata=StandardLoggingMetadata(
+            user_api_key_hash="test_hash",
+            user_api_key_alias="test_alias",
+            user_api_key_team_id="test_team",
+            user_api_key_user_id="test_user",
+            user_api_key_team_alias="test_team_alias",
+            spend_logs_metadata=None,
+            requester_ip_address="127.0.0.1",
+            requester_metadata=None,
+        ),
+        cache_hit=False,
+        cache_key=None,
+        saved_cache_cost=0.0,
+        request_tags=[],
+        end_user=None,
+        requester_ip_address="127.0.0.1",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        response={"choices": [{"message": {"content": "Hi there!"}}]},
+        error_str=None,
+        model_parameters={"stream": True},
+        hidden_params=StandardLoggingHiddenParams(
+            model_id="model-123",
+            cache_key=None,
+            api_base="https://api.openai.com",
+            response_cost="0.1",
+            additional_headers=None,
+        ),
+    )
+
+
+def test_safe_get_remaining_budget(prometheus_logger):
+    assert prometheus_logger._safe_get_remaining_budget(100, 30) == 70
+    assert prometheus_logger._safe_get_remaining_budget(100, None) == 100
+    assert prometheus_logger._safe_get_remaining_budget(None, 30) == float("inf")
+    assert prometheus_logger._safe_get_remaining_budget(None, None) == float("inf")
+
+
+@pytest.mark.asyncio
+async def test_async_log_success_event(prometheus_logger):
+    standard_logging_object = create_standard_logging_payload()
+    kwargs = {
+        "model": "gpt-3.5-turbo",
+        "litellm_params": {
+            "metadata": {
+                "user_api_key": "test_key",
+                "user_api_key_user_id": "test_user",
+                "user_api_key_team_id": "test_team",
+            }
+        },
+        "start_time": datetime.now(),
+        "completion_start_time": datetime.now(),
+        "api_call_start_time": datetime.now(),
+        "end_time": datetime.now() + timedelta(seconds=1),
+        "standard_logging_object": standard_logging_object,
+    }
+    response_obj = MagicMock()
+
+    # Mock the prometheus client methods
+
+    # High Level Metrics - request/spend
+    prometheus_logger.litellm_requests_metric = MagicMock()
+    prometheus_logger.litellm_spend_metric = MagicMock()
+
+    # Token Metrics
+    prometheus_logger.litellm_tokens_metric = MagicMock()
+    prometheus_logger.litellm_input_tokens_metric = MagicMock()
+    prometheus_logger.litellm_output_tokens_metric = MagicMock()
+
+    # Remaining Budget Metrics
+    prometheus_logger.litellm_remaining_team_budget_metric = MagicMock()
+    prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock()
+
+    # Virtual Key Rate limit Metrics
+    prometheus_logger.litellm_remaining_api_key_requests_for_model = MagicMock()
+    prometheus_logger.litellm_remaining_api_key_tokens_for_model = MagicMock()
+
+    # Latency Metrics
+    prometheus_logger.litellm_llm_api_time_to_first_token_metric = MagicMock()
+    prometheus_logger.litellm_llm_api_latency_metric = MagicMock()
+    prometheus_logger.litellm_request_total_latency_metric = MagicMock()
+
+    await prometheus_logger.async_log_success_event(
+        kwargs, response_obj, kwargs["start_time"], kwargs["end_time"]
+    )
+
+    # Assert that the metrics were incremented
+    prometheus_logger.litellm_requests_metric.labels.assert_called()
+    prometheus_logger.litellm_spend_metric.labels.assert_called()
+
+    # Token Metrics
+    prometheus_logger.litellm_tokens_metric.labels.assert_called()
+    prometheus_logger.litellm_input_tokens_metric.labels.assert_called()
+    prometheus_logger.litellm_output_tokens_metric.labels.assert_called()
+
+    # Remaining Budget Metrics
+    prometheus_logger.litellm_remaining_team_budget_metric.labels.assert_called()
+    prometheus_logger.litellm_remaining_api_key_budget_metric.labels.assert_called()
+
+    # Virtual Key Rate limit Metrics
+    prometheus_logger.litellm_remaining_api_key_requests_for_model.labels.assert_called()
+    prometheus_logger.litellm_remaining_api_key_tokens_for_model.labels.assert_called()
+
+    # Latency Metrics
+    prometheus_logger.litellm_llm_api_time_to_first_token_metric.labels.assert_called()
+    prometheus_logger.litellm_llm_api_latency_metric.labels.assert_called()
+    prometheus_logger.litellm_request_total_latency_metric.labels.assert_called()
+
+
+def test_increment_token_metrics(prometheus_logger):
+    """
+    Test the increment_token_metrics method
+
+    input, output, and total tokens metrics are incremented by the values in the standard logging payload
+    """
+    prometheus_logger.litellm_tokens_metric = MagicMock()
+    prometheus_logger.litellm_input_tokens_metric = MagicMock()
+    prometheus_logger.litellm_output_tokens_metric = MagicMock()
+
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["total_tokens"] = 100
+    standard_logging_payload["prompt_tokens"] = 50
+    standard_logging_payload["completion_tokens"] = 50
+
+    prometheus_logger._increment_token_metrics(
+        standard_logging_payload,
+        end_user_id="user1",
+        user_api_key="key1",
+        user_api_key_alias="alias1",
+        model="gpt-3.5-turbo",
+        user_api_team="team1",
+        user_api_team_alias="team_alias1",
+        user_id="user1",
+    )
+
+    prometheus_logger.litellm_tokens_metric.labels.assert_called_once_with(
+        "user1", "key1", "alias1", "gpt-3.5-turbo", "team1", "team_alias1", "user1"
+    )
+    prometheus_logger.litellm_tokens_metric.labels().inc.assert_called_once_with(100)
+
+    prometheus_logger.litellm_input_tokens_metric.labels.assert_called_once_with(
+        "user1", "key1", "alias1", "gpt-3.5-turbo", "team1", "team_alias1", "user1"
+    )
+    prometheus_logger.litellm_input_tokens_metric.labels().inc.assert_called_once_with(
+        50
+    )
+
+    prometheus_logger.litellm_output_tokens_metric.labels.assert_called_once_with(
+        "user1", "key1", "alias1", "gpt-3.5-turbo", "team1", "team_alias1", "user1"
+    )
+    prometheus_logger.litellm_output_tokens_metric.labels().inc.assert_called_once_with(
+        50
+    )
+
+
+def test_increment_remaining_budget_metrics(prometheus_logger):
+    """
+    Test the increment_remaining_budget_metrics method
+
+    team and api key budget metrics are set to the difference between max budget and spend
+    """
+    prometheus_logger.litellm_remaining_team_budget_metric = MagicMock()
+    prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock()
+
+    litellm_params = {
+        "metadata": {
+            "user_api_key_team_spend": 50,
+            "user_api_key_team_max_budget": 100,
+            "user_api_key_spend": 25,
+            "user_api_key_max_budget": 75,
+        }
+    }
+
+    prometheus_logger._increment_remaining_budget_metrics(
+        user_api_team="team1",
+        user_api_team_alias="team_alias1",
+        user_api_key="key1",
+        user_api_key_alias="alias1",
+        litellm_params=litellm_params,
+    )
+
+    prometheus_logger.litellm_remaining_team_budget_metric.labels.assert_called_once_with(
+        "team1", "team_alias1"
+    )
+    prometheus_logger.litellm_remaining_team_budget_metric.labels().set.assert_called_once_with(
+        50
+    )
+
+    prometheus_logger.litellm_remaining_api_key_budget_metric.labels.assert_called_once_with(
+        "key1", "alias1"
+    )
+    prometheus_logger.litellm_remaining_api_key_budget_metric.labels().set.assert_called_once_with(
+        50
+    )
+
+
+def test_set_latency_metrics(prometheus_logger):
+    """
+    Test the set_latency_metrics method
+
+    time to first token, llm api latency, and request total latency metrics are set to the values in the standard logging payload
+    """
+    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload["model_parameters"] = {"stream": True}
+    prometheus_logger.litellm_llm_api_time_to_first_token_metric = MagicMock()
+    prometheus_logger.litellm_llm_api_latency_metric = MagicMock()
+    prometheus_logger.litellm_request_total_latency_metric = MagicMock()
+
+    now = datetime.now()
+    kwargs = {
+        "end_time": now,  # when the request ends
+        "start_time": now - timedelta(seconds=2),  # when the request starts
+        "api_call_start_time": now - timedelta(seconds=1.5),  # when the api call starts
+        "completion_start_time": now
+        - timedelta(seconds=1),  # when the completion starts
+    }
+
+    prometheus_logger._set_latency_metrics(
+        kwargs=kwargs,
+        model="gpt-3.5-turbo",
+        user_api_key="key1",
+        user_api_key_alias="alias1",
+        user_api_team="team1",
+        user_api_team_alias="team_alias1",
+        standard_logging_payload=standard_logging_payload,
+    )
+
+    # completion_start_time - api_call_start_time
+    prometheus_logger.litellm_llm_api_time_to_first_token_metric.labels.assert_called_once_with(
+        "gpt-3.5-turbo", "key1", "alias1", "team1", "team_alias1"
+    )
+    prometheus_logger.litellm_llm_api_time_to_first_token_metric.labels().observe.assert_called_once_with(
+        0.5
+    )
+
+    # end_time - api_call_start_time
+    prometheus_logger.litellm_llm_api_latency_metric.labels.assert_called_once_with(
+        "gpt-3.5-turbo", "key1", "alias1", "team1", "team_alias1"
+    )
+    prometheus_logger.litellm_llm_api_latency_metric.labels().observe.assert_called_once_with(
+        1.5
+    )
+
+    # total latency for the request
+    prometheus_logger.litellm_request_total_latency_metric.labels.assert_called_once_with(
+        "gpt-3.5-turbo", "key1", "alias1", "team1", "team_alias1"
+    )
+    prometheus_logger.litellm_request_total_latency_metric.labels().observe.assert_called_once_with(
+        2.0
+    )
+
+
+def test_increment_top_level_request_and_spend_metrics(prometheus_logger):
+    """
+    Test the increment_top_level_request_and_spend_metrics method
+
+    - litellm_requests_metric is incremented by 1
+    - litellm_spend_metric is incremented by the response cost in the standard logging payload
+    """
+    prometheus_logger.litellm_requests_metric = MagicMock()
+    prometheus_logger.litellm_spend_metric = MagicMock()
+
+    prometheus_logger._increment_top_level_request_and_spend_metrics(
+        end_user_id="user1",
+        user_api_key="key1",
+        user_api_key_alias="alias1",
+        model="gpt-3.5-turbo",
+        user_api_team="team1",
+        user_api_team_alias="team_alias1",
+        user_id="user1",
+        response_cost=0.1,
+    )
+
+    prometheus_logger.litellm_requests_metric.labels.assert_called_once_with(
+        "user1", "key1", "alias1", "gpt-3.5-turbo", "team1", "team_alias1", "user1"
+    )
+    prometheus_logger.litellm_requests_metric.labels().inc.assert_called_once()
+
+    prometheus_logger.litellm_spend_metric.labels.assert_called_once_with(
+        "user1", "key1", "alias1", "gpt-3.5-turbo", "team1", "team_alias1", "user1"
+    )
+    prometheus_logger.litellm_spend_metric.labels().inc.assert_called_once_with(0.1)
diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
index e0a7e85f5..5b14717dc 100644
--- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
+++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py
@@ -17,6 +17,7 @@ import litellm
 import asyncio
 import logging
 from litellm._logging import verbose_logger
+from prometheus_client import REGISTRY, CollectorRegistry
 
 from litellm.integrations.lago import LagoLogger
 from litellm.integrations.openmeter import OpenMeterLogger
@@ -33,6 +34,12 @@ from litellm.integrations.argilla import ArgillaLogger
 from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
 from unittest.mock import patch
 
+# clear prometheus collectors / registry
+collectors = list(REGISTRY._collector_to_names.keys())
+for collector in collectors:
+    REGISTRY.unregister(collector)
+######################################
+
 callback_class_str_to_classType = {
     "lago": LagoLogger,
     "openmeter": OpenMeterLogger,
@@ -111,6 +118,11 @@ async def use_callback_in_llm_call(
     elif callback == "openmeter":
         # it's currently handled in jank way, TODO: fix openmete and then actually run it's test
         return
+    elif callback == "prometheus":
+        # pytest teardown - clear existing prometheus collectors
+        collectors = list(REGISTRY._collector_to_names.keys())
+        for collector in collectors:
+            REGISTRY.unregister(collector)
 
     # Mock the httpx call for Argilla dataset retrieval
     if callback == "argilla":