diff --git a/litellm/tests/test_prometheus.py b/litellm/tests/test_prometheus.py
index 1232130cb..a7f9ef388 100644
--- a/litellm/tests/test_prometheus.py
+++ b/litellm/tests/test_prometheus.py
@@ -9,7 +9,7 @@ import logging
 import uuid
 
 import pytest
-from prometheus_client import REGISTRY
+from prometheus_client import REGISTRY, CollectorRegistry
 
 import litellm
 from litellm import completion
@@ -85,8 +85,17 @@ async def test_async_prometheus_success_logging():
 async def test_async_prometheus_success_logging_with_callbacks():
     run_id = str(uuid.uuid4())
     litellm.set_verbose = True
+
+    litellm.success_callback = []
+    litellm.failure_callback = []
     litellm.callbacks = ["prometheus"]
 
+    # Get initial metric values
+    initial_metrics = {}
+    for metric in REGISTRY.collect():
+        for sample in metric.samples:
+            initial_metrics[sample.name] = sample.value
+
     response = await litellm.acompletion(
         model="claude-instant-1.2",
         messages=[{"role": "user", "content": "what llm are u"}],
@@ -124,15 +133,37 @@ async def test_async_prometheus_success_logging_with_callbacks():
         vars(test_prometheus_logger.litellm_requests_metric),
     )
 
-    # Get the metrics
-    metrics = {}
+    # Get the updated metrics
+    updated_metrics = {}
     for metric in REGISTRY.collect():
         for sample in metric.samples:
-            metrics[sample.name] = sample.value
+            updated_metrics[sample.name] = sample.value
 
-    print("metrics from prometheus", metrics)
-    assert metrics["litellm_requests_metric_total"] == 1.0
-    assert metrics["litellm_total_tokens_total"] == 30.0
-    assert metrics["litellm_deployment_success_responses_total"] == 1.0
-    assert metrics["litellm_deployment_total_requests_total"] == 1.0
-    assert metrics["litellm_deployment_latency_per_output_token_bucket"] == 1.0
+    print("metrics from prometheus", updated_metrics)
+
+    # Assert the delta for each metric
+    assert (
+        updated_metrics["litellm_requests_metric_total"]
+        - initial_metrics.get("litellm_requests_metric_total", 0)
+        == 1.0
+    )
+    assert (
+        updated_metrics["litellm_total_tokens_total"]
+        - initial_metrics.get("litellm_total_tokens_total", 0)
+        == 30.0
+    )
+    assert (
+        updated_metrics["litellm_deployment_success_responses_total"]
+        - initial_metrics.get("litellm_deployment_success_responses_total", 0)
+        == 1.0
+    )
+    assert (
+        updated_metrics["litellm_deployment_total_requests_total"]
+        - initial_metrics.get("litellm_deployment_total_requests_total", 0)
+        == 1.0
+    )
+    assert (
+        updated_metrics["litellm_deployment_latency_per_output_token_bucket"]
+        - initial_metrics.get("litellm_deployment_latency_per_output_token_bucket", 0)
+        == 1.0
+    )
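
A note on the pattern this diff introduces: prometheus_client's default REGISTRY is process-global, so absolute assertions break as soon as any other test in the session emits the same metrics; snapshotting the registry before the call and asserting on deltas afterwards makes the test order-independent. Below is a minimal standalone sketch of that pattern, assuming only prometheus_client; snapshot_registry and demo_requests are hypothetical names for illustration, not litellm or prometheus_client APIs.

from prometheus_client import REGISTRY, Counter

def snapshot_registry(registry=REGISTRY):
    """Flatten every sample in the registry into a {name: value} dict."""
    return {
        sample.name: sample.value
        for metric in registry.collect()
        for sample in metric.samples
    }

# Hypothetical counter standing in for litellm's metrics.
demo_requests = Counter("demo_requests", "Requests handled by the demo")

before = snapshot_registry()
demo_requests.inc()  # the code under test would do this as a side effect
after = snapshot_registry()

# Counters expose a "<name>_total" sample; asserting on the delta keeps the
# check correct even when earlier tests already incremented the counter.
assert after["demo_requests_total"] - before.get("demo_requests_total", 0) == 1.0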