(refactor) prometheus async_log_success_event to be under 100 LOC (#6416)

* unit testing for prometheus

* unit testing for success metrics

* use 1 helper for _increment_token_metrics

* use helper for _increment_remaining_budget_metrics

* use _increment_remaining_budget_metrics

* use _increment_top_level_request_and_spend_metrics

* use helper for _set_latency_metrics

* remove noqa violation

* fix test prometheus

* test prometheus

* unit testing for all prometheus helper functions

* fix prom unit tests

* fix unit tests prometheus

* fix unit test prom
Ishaan Jaff 2024-10-24 16:41:09 +04:00 committed by GitHub
parent ca09f4afec
commit cdda7c243f
4 changed files with 540 additions and 77 deletions
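
The refactor below is the usual extract-helper pattern: async_log_success_event drops its "# noqa: PLR0915" (too-many-statements) suppression by delegating each metric group to a small private helper, and every helper is then unit-tested in isolation by swapping the Prometheus metric objects for MagicMock. A minimal, hypothetical sketch of that pattern follows; TinyLogger and its metric names are illustrative stand-ins, not the actual litellm code.

    # Minimal sketch of the extract-helper + mock-and-assert pattern used in this commit.
    # TinyLogger, requests_metric, spend_metric are illustrative, not litellm APIs.
    from typing import Optional
    from unittest.mock import MagicMock


    class TinyLogger:
        def __init__(self):
            # stand-ins for prometheus_client Counter objects
            self.requests_metric = MagicMock()
            self.spend_metric = MagicMock()

        def log_success(self, model: str, user: Optional[str], response_cost: float) -> None:
            # the former monolithic handler now only orchestrates helpers
            self._increment_request_and_spend(model, user, response_cost)

        def _increment_request_and_spend(
            self, model: str, user: Optional[str], response_cost: float
        ) -> None:
            self.requests_metric.labels(model, user).inc()
            self.spend_metric.labels(model, user).inc(response_cost)


    def test_increment_request_and_spend():
        logger = TinyLogger()
        logger._increment_request_and_spend("gpt-3.5-turbo", "user1", 0.1)
        logger.requests_metric.labels.assert_called_once_with("gpt-3.5-turbo", "user1")
        logger.requests_metric.labels().inc.assert_called_once()
        logger.spend_metric.labels().inc.assert_called_once_with(0.1)


    if __name__ == "__main__":
        test_increment_request_and_spend()
        print("ok")
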


@@ -6,7 +6,7 @@ import subprocess
 import sys
 import traceback
 import uuid
-from datetime import datetime, timedelta
+from datetime import date, datetime, timedelta
 from typing import Optional, TypedDict, Union

 import dotenv
@@ -334,13 +334,8 @@ class PrometheusLogger(CustomLogger):
             print_verbose(f"Got exception on init prometheus client {str(e)}")
             raise e

-    async def async_log_success_event(  # noqa: PLR0915
-        self, kwargs, response_obj, start_time, end_time
-    ):
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         # Define prometheus client
-        from litellm.proxy.common_utils.callback_utils import (
-            get_model_group_from_litellm_kwargs,
-        )
         from litellm.types.utils import StandardLoggingPayload

         verbose_logger.debug(
@@ -358,7 +353,6 @@ class PrometheusLogger(CustomLogger):
         _metadata = litellm_params.get("metadata", {})
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         end_user_id = proxy_server_request.get("body", {}).get("user", None)
-        model_parameters: dict = standard_logging_payload["model_parameters"]
         user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
         user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
         user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
@@ -369,25 +363,6 @@
         output_tokens = standard_logging_payload["completion_tokens"]
         tokens_used = standard_logging_payload["total_tokens"]
         response_cost = standard_logging_payload["response_cost"]
-        _team_spend = litellm_params.get("metadata", {}).get(
-            "user_api_key_team_spend", None
-        )
-        _team_max_budget = litellm_params.get("metadata", {}).get(
-            "user_api_key_team_max_budget", None
-        )
-        _remaining_team_budget = safe_get_remaining_budget(
-            max_budget=_team_max_budget, spend=_team_spend
-        )
-
-        _api_key_spend = litellm_params.get("metadata", {}).get(
-            "user_api_key_spend", None
-        )
-        _api_key_max_budget = litellm_params.get("metadata", {}).get(
-            "user_api_key_max_budget", None
-        )
-        _remaining_api_key_budget = safe_get_remaining_budget(
-            max_budget=_api_key_max_budget, spend=_api_key_spend
-        )

         print_verbose(
             f"inside track_prometheus_metrics, model {model}, response_cost {response_cost}, tokens_used {tokens_used}, end_user_id {end_user_id}, user_api_key {user_api_key}"
@@ -402,24 +377,76 @@ class PrometheusLogger(CustomLogger):
                 user_api_key = hash_token(user_api_key)

-        self.litellm_requests_metric.labels(
-            end_user_id,
-            user_api_key,
-            user_api_key_alias,
-            model,
-            user_api_team,
-            user_api_team_alias,
-            user_id,
-        ).inc()
-        self.litellm_spend_metric.labels(
-            end_user_id,
-            user_api_key,
-            user_api_key_alias,
-            model,
-            user_api_team,
-            user_api_team_alias,
-            user_id,
-        ).inc(response_cost)
+        # increment total LLM requests and spend metric
+        self._increment_top_level_request_and_spend_metrics(
+            end_user_id=end_user_id,
+            user_api_key=user_api_key,
+            user_api_key_alias=user_api_key_alias,
+            model=model,
+            user_api_team=user_api_team,
+            user_api_team_alias=user_api_team_alias,
+            user_id=user_id,
+            response_cost=response_cost,
+        )
+
+        # input, output, total token metrics
+        self._increment_token_metrics(
+            standard_logging_payload=standard_logging_payload,
+            end_user_id=end_user_id,
+            user_api_key=user_api_key,
+            user_api_key_alias=user_api_key_alias,
+            model=model,
+            user_api_team=user_api_team,
+            user_api_team_alias=user_api_team_alias,
+            user_id=user_id,
+        )
+
+        # remaining budget metrics
+        self._increment_remaining_budget_metrics(
+            user_api_team=user_api_team,
+            user_api_team_alias=user_api_team_alias,
+            user_api_key=user_api_key,
+            user_api_key_alias=user_api_key_alias,
+            litellm_params=litellm_params,
+        )
+
+        # set proxy virtual key rpm/tpm metrics
+        self._set_virtual_key_rate_limit_metrics(
+            user_api_key=user_api_key,
+            user_api_key_alias=user_api_key_alias,
+            kwargs=kwargs,
+            metadata=_metadata,
+        )
+
+        # set latency metrics
+        self._set_latency_metrics(
+            kwargs=kwargs,
+            model=model,
+            user_api_key=user_api_key,
+            user_api_key_alias=user_api_key_alias,
+            user_api_team=user_api_team,
+            user_api_team_alias=user_api_team_alias,
+            standard_logging_payload=standard_logging_payload,
+        )
+
+        # set x-ratelimit headers
+        self.set_llm_deployment_success_metrics(
+            kwargs, start_time, end_time, output_tokens
+        )
+        pass
+
+    def _increment_token_metrics(
+        self,
+        standard_logging_payload: StandardLoggingPayload,
+        end_user_id: Optional[str],
+        user_api_key: Optional[str],
+        user_api_key_alias: Optional[str],
+        model: Optional[str],
+        user_api_team: Optional[str],
+        user_api_team_alias: Optional[str],
+        user_id: Optional[str],
+    ):
+        # token metrics
         self.litellm_tokens_metric.labels(
             end_user_id,
             user_api_key,
@@ -450,6 +477,34 @@ class PrometheusLogger(CustomLogger):
             user_id,
         ).inc(standard_logging_payload["completion_tokens"])

+    def _increment_remaining_budget_metrics(
+        self,
+        user_api_team: Optional[str],
+        user_api_team_alias: Optional[str],
+        user_api_key: Optional[str],
+        user_api_key_alias: Optional[str],
+        litellm_params: dict,
+    ):
+        _team_spend = litellm_params.get("metadata", {}).get(
+            "user_api_key_team_spend", None
+        )
+        _team_max_budget = litellm_params.get("metadata", {}).get(
+            "user_api_key_team_max_budget", None
+        )
+        _remaining_team_budget = self._safe_get_remaining_budget(
+            max_budget=_team_max_budget, spend=_team_spend
+        )
+
+        _api_key_spend = litellm_params.get("metadata", {}).get(
+            "user_api_key_spend", None
+        )
+        _api_key_max_budget = litellm_params.get("metadata", {}).get(
+            "user_api_key_max_budget", None
+        )
+        _remaining_api_key_budget = self._safe_get_remaining_budget(
+            max_budget=_api_key_max_budget, spend=_api_key_spend
+        )
+
+        # Remaining Budget Metrics
         self.litellm_remaining_team_budget_metric.labels(
             user_api_team, user_api_team_alias
         ).set(_remaining_team_budget)
@@ -458,6 +513,47 @@ class PrometheusLogger(CustomLogger):
             user_api_key, user_api_key_alias
         ).set(_remaining_api_key_budget)

+    def _increment_top_level_request_and_spend_metrics(
+        self,
+        end_user_id: Optional[str],
+        user_api_key: Optional[str],
+        user_api_key_alias: Optional[str],
+        model: Optional[str],
+        user_api_team: Optional[str],
+        user_api_team_alias: Optional[str],
+        user_id: Optional[str],
+        response_cost: float,
+    ):
+        self.litellm_requests_metric.labels(
+            end_user_id,
+            user_api_key,
+            user_api_key_alias,
+            model,
+            user_api_team,
+            user_api_team_alias,
+            user_id,
+        ).inc()
+        self.litellm_spend_metric.labels(
+            end_user_id,
+            user_api_key,
+            user_api_key_alias,
+            model,
+            user_api_team,
+            user_api_team_alias,
+            user_id,
+        ).inc(response_cost)
+
+    def _set_virtual_key_rate_limit_metrics(
+        self,
+        user_api_key: Optional[str],
+        user_api_key_alias: Optional[str],
+        kwargs: dict,
+        metadata: dict,
+    ):
+        from litellm.proxy.common_utils.callback_utils import (
+            get_model_group_from_litellm_kwargs,
+        )
+
         # Set remaining rpm/tpm for API Key + model
         # see parallel_request_limiter.py - variables are set there
         model_group = get_model_group_from_litellm_kwargs(kwargs)
@@ -466,10 +562,8 @@
         )
         remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}"

-        remaining_requests = _metadata.get(
-            remaining_requests_variable_name, sys.maxsize
-        )
-        remaining_tokens = _metadata.get(remaining_tokens_variable_name, sys.maxsize)
+        remaining_requests = metadata.get(remaining_requests_variable_name, sys.maxsize)
+        remaining_tokens = metadata.get(remaining_tokens_variable_name, sys.maxsize)

         self.litellm_remaining_api_key_requests_for_model.labels(
             user_api_key, user_api_key_alias, model_group
@@ -479,9 +573,20 @@
             user_api_key, user_api_key_alias, model_group
         ).set(remaining_tokens)

+    def _set_latency_metrics(
+        self,
+        kwargs: dict,
+        model: Optional[str],
+        user_api_key: Optional[str],
+        user_api_key_alias: Optional[str],
+        user_api_team: Optional[str],
+        user_api_team_alias: Optional[str],
+        standard_logging_payload: StandardLoggingPayload,
+    ):
         # latency metrics
-        total_time: timedelta = kwargs.get("end_time") - kwargs.get("start_time")
-        total_time_seconds = total_time.total_seconds()
+        model_parameters: dict = standard_logging_payload["model_parameters"]
+        end_time: datetime = kwargs.get("end_time") or datetime.now()
+        start_time: Optional[datetime] = kwargs.get("start_time")
         api_call_start_time = kwargs.get("api_call_start_time", None)
         completion_start_time = kwargs.get("completion_start_time", None)
@@ -509,9 +614,7 @@
         if api_call_start_time is not None and isinstance(
             api_call_start_time, datetime
         ):
-            api_call_total_time: timedelta = (
-                kwargs.get("end_time") - api_call_start_time
-            )
+            api_call_total_time: timedelta = end_time - api_call_start_time
             api_call_total_time_seconds = api_call_total_time.total_seconds()
             self.litellm_llm_api_latency_metric.labels(
                 model,
@@ -521,7 +624,10 @@
                 user_api_team_alias,
             ).observe(api_call_total_time_seconds)

-        # log metrics
+        # total request latency
+        if start_time is not None and isinstance(start_time, datetime):
+            total_time: timedelta = end_time - start_time
+            total_time_seconds = total_time.total_seconds()
             self.litellm_request_total_latency_metric.labels(
                 model,
                 user_api_key,
@@ -530,12 +636,6 @@
                 user_api_team_alias,
             ).observe(total_time_seconds)

-        # set x-ratelimit headers
-        self.set_llm_deployment_success_metrics(
-            kwargs, start_time, end_time, output_tokens
-        )
-        pass
-
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         from litellm.types.utils import StandardLoggingPayload
@@ -1007,9 +1107,8 @@
             litellm_model_name, model_id, api_base, api_provider, exception_status
         ).inc()

-
-def safe_get_remaining_budget(
-    max_budget: Optional[float], spend: Optional[float]
+    def _safe_get_remaining_budget(
+        self, max_budget: Optional[float], spend: Optional[float]
     ) -> float:
         if max_budget is None:
             return float("inf")


@@ -16,6 +16,14 @@ from litellm import completion
 from litellm._logging import verbose_logger
 from litellm.integrations.prometheus import PrometheusLogger
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.types.utils import (
+    StandardLoggingPayload,
+    StandardLoggingMetadata,
+    StandardLoggingHiddenParams,
+    StandardLoggingModelInformation,
+)
+from unittest.mock import MagicMock, patch
+from datetime import datetime, timedelta

 verbose_logger.setLevel(logging.DEBUG)


@@ -0,0 +1,344 @@
import io
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import asyncio
import logging
import uuid
import pytest
from prometheus_client import REGISTRY, CollectorRegistry
import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.integrations.prometheus import PrometheusLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.utils import (
StandardLoggingPayload,
StandardLoggingMetadata,
StandardLoggingHiddenParams,
StandardLoggingModelInformation,
)
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime, timedelta
from litellm.integrations.prometheus import PrometheusLogger
verbose_logger.setLevel(logging.DEBUG)
litellm.set_verbose = True
import time
@pytest.fixture
def prometheus_logger():
collectors = list(REGISTRY._collector_to_names.keys())
for collector in collectors:
REGISTRY.unregister(collector)
return PrometheusLogger()
def create_standard_logging_payload() -> StandardLoggingPayload:
return StandardLoggingPayload(
id="test_id",
call_type="completion",
response_cost=0.1,
response_cost_failure_debug_info=None,
status="success",
total_tokens=30,
prompt_tokens=20,
completion_tokens=10,
startTime=1234567890.0,
endTime=1234567891.0,
completionStartTime=1234567890.5,
model_map_information=StandardLoggingModelInformation(
model_map_key="gpt-3.5-turbo", model_map_value=None
),
model="gpt-3.5-turbo",
model_id="model-123",
model_group="openai-gpt",
api_base="https://api.openai.com",
metadata=StandardLoggingMetadata(
user_api_key_hash="test_hash",
user_api_key_alias="test_alias",
user_api_key_team_id="test_team",
user_api_key_user_id="test_user",
user_api_key_team_alias="test_team_alias",
spend_logs_metadata=None,
requester_ip_address="127.0.0.1",
requester_metadata=None,
),
cache_hit=False,
cache_key=None,
saved_cache_cost=0.0,
request_tags=[],
end_user=None,
requester_ip_address="127.0.0.1",
messages=[{"role": "user", "content": "Hello, world!"}],
response={"choices": [{"message": {"content": "Hi there!"}}]},
error_str=None,
model_parameters={"stream": True},
hidden_params=StandardLoggingHiddenParams(
model_id="model-123",
cache_key=None,
api_base="https://api.openai.com",
response_cost="0.1",
additional_headers=None,
),
)
def test_safe_get_remaining_budget(prometheus_logger):
assert prometheus_logger._safe_get_remaining_budget(100, 30) == 70
assert prometheus_logger._safe_get_remaining_budget(100, None) == 100
assert prometheus_logger._safe_get_remaining_budget(None, 30) == float("inf")
assert prometheus_logger._safe_get_remaining_budget(None, None) == float("inf")
@pytest.mark.asyncio
async def test_async_log_success_event(prometheus_logger):
standard_logging_object = create_standard_logging_payload()
kwargs = {
"model": "gpt-3.5-turbo",
"litellm_params": {
"metadata": {
"user_api_key": "test_key",
"user_api_key_user_id": "test_user",
"user_api_key_team_id": "test_team",
}
},
"start_time": datetime.now(),
"completion_start_time": datetime.now(),
"api_call_start_time": datetime.now(),
"end_time": datetime.now() + timedelta(seconds=1),
"standard_logging_object": standard_logging_object,
}
response_obj = MagicMock()
# Mock the prometheus client methods
# High Level Metrics - request/spend
prometheus_logger.litellm_requests_metric = MagicMock()
prometheus_logger.litellm_spend_metric = MagicMock()
# Token Metrics
prometheus_logger.litellm_tokens_metric = MagicMock()
prometheus_logger.litellm_input_tokens_metric = MagicMock()
prometheus_logger.litellm_output_tokens_metric = MagicMock()
# Remaining Budget Metrics
prometheus_logger.litellm_remaining_team_budget_metric = MagicMock()
prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock()
# Virtual Key Rate limit Metrics
prometheus_logger.litellm_remaining_api_key_requests_for_model = MagicMock()
prometheus_logger.litellm_remaining_api_key_tokens_for_model = MagicMock()
# Latency Metrics
prometheus_logger.litellm_llm_api_time_to_first_token_metric = MagicMock()
prometheus_logger.litellm_llm_api_latency_metric = MagicMock()
prometheus_logger.litellm_request_total_latency_metric = MagicMock()
await prometheus_logger.async_log_success_event(
kwargs, response_obj, kwargs["start_time"], kwargs["end_time"]
)
# Assert that the metrics were incremented
prometheus_logger.litellm_requests_metric.labels.assert_called()
prometheus_logger.litellm_spend_metric.labels.assert_called()
# Token Metrics
prometheus_logger.litellm_tokens_metric.labels.assert_called()
prometheus_logger.litellm_input_tokens_metric.labels.assert_called()
prometheus_logger.litellm_output_tokens_metric.labels.assert_called()
# Remaining Budget Metrics
prometheus_logger.litellm_remaining_team_budget_metric.labels.assert_called()
prometheus_logger.litellm_remaining_api_key_budget_metric.labels.assert_called()
# Virtual Key Rate limit Metrics
prometheus_logger.litellm_remaining_api_key_requests_for_model.labels.assert_called()
prometheus_logger.litellm_remaining_api_key_tokens_for_model.labels.assert_called()
# Latency Metrics
prometheus_logger.litellm_llm_api_time_to_first_token_metric.labels.assert_called()
prometheus_logger.litellm_llm_api_latency_metric.labels.assert_called()
prometheus_logger.litellm_request_total_latency_metric.labels.assert_called()
def test_increment_token_metrics(prometheus_logger):
"""
Test the increment_token_metrics method
input, output, and total tokens metrics are incremented by the values in the standard logging payload
"""
prometheus_logger.litellm_tokens_metric = MagicMock()
prometheus_logger.litellm_input_tokens_metric = MagicMock()
prometheus_logger.litellm_output_tokens_metric = MagicMock()
standard_logging_payload = create_standard_logging_payload()
standard_logging_payload["total_tokens"] = 100
standard_logging_payload["prompt_tokens"] = 50
standard_logging_payload["completion_tokens"] = 50
prometheus_logger._increment_token_metrics(
standard_logging_payload,
end_user_id="user1",
user_api_key="key1",
user_api_key_alias="alias1",
model="gpt-3.5-turbo",
user_api_team="team1",
user_api_team_alias="team_alias1",
user_id="user1",
)
prometheus_logger.litellm_tokens_metric.labels.assert_called_once_with(
"user1", "key1", "alias1", "gpt-3.5-turbo", "team1", "team_alias1", "user1"
)
prometheus_logger.litellm_tokens_metric.labels().inc.assert_called_once_with(100)
prometheus_logger.litellm_input_tokens_metric.labels.assert_called_once_with(
"user1", "key1", "alias1", "gpt-3.5-turbo", "team1", "team_alias1", "user1"
)
prometheus_logger.litellm_input_tokens_metric.labels().inc.assert_called_once_with(
50
)
prometheus_logger.litellm_output_tokens_metric.labels.assert_called_once_with(
"user1", "key1", "alias1", "gpt-3.5-turbo", "team1", "team_alias1", "user1"
)
prometheus_logger.litellm_output_tokens_metric.labels().inc.assert_called_once_with(
50
)
def test_increment_remaining_budget_metrics(prometheus_logger):
"""
Test the increment_remaining_budget_metrics method
team and api key budget metrics are set to the difference between max budget and spend
"""
prometheus_logger.litellm_remaining_team_budget_metric = MagicMock()
prometheus_logger.litellm_remaining_api_key_budget_metric = MagicMock()
litellm_params = {
"metadata": {
"user_api_key_team_spend": 50,
"user_api_key_team_max_budget": 100,
"user_api_key_spend": 25,
"user_api_key_max_budget": 75,
}
}
prometheus_logger._increment_remaining_budget_metrics(
user_api_team="team1",
user_api_team_alias="team_alias1",
user_api_key="key1",
user_api_key_alias="alias1",
litellm_params=litellm_params,
)
prometheus_logger.litellm_remaining_team_budget_metric.labels.assert_called_once_with(
"team1", "team_alias1"
)
prometheus_logger.litellm_remaining_team_budget_metric.labels().set.assert_called_once_with(
50
)
prometheus_logger.litellm_remaining_api_key_budget_metric.labels.assert_called_once_with(
"key1", "alias1"
)
prometheus_logger.litellm_remaining_api_key_budget_metric.labels().set.assert_called_once_with(
50
)
def test_set_latency_metrics(prometheus_logger):
"""
Test the set_latency_metrics method
time to first token, llm api latency, and request total latency metrics are set to the values in the standard logging payload
"""
standard_logging_payload = create_standard_logging_payload()
standard_logging_payload["model_parameters"] = {"stream": True}
prometheus_logger.litellm_llm_api_time_to_first_token_metric = MagicMock()
prometheus_logger.litellm_llm_api_latency_metric = MagicMock()
prometheus_logger.litellm_request_total_latency_metric = MagicMock()
now = datetime.now()
kwargs = {
"end_time": now, # when the request ends
"start_time": now - timedelta(seconds=2), # when the request starts
"api_call_start_time": now - timedelta(seconds=1.5), # when the api call starts
"completion_start_time": now
- timedelta(seconds=1), # when the completion starts
}
prometheus_logger._set_latency_metrics(
kwargs=kwargs,
model="gpt-3.5-turbo",
user_api_key="key1",
user_api_key_alias="alias1",
user_api_team="team1",
user_api_team_alias="team_alias1",
standard_logging_payload=standard_logging_payload,
)
# completion_start_time - api_call_start_time
prometheus_logger.litellm_llm_api_time_to_first_token_metric.labels.assert_called_once_with(
"gpt-3.5-turbo", "key1", "alias1", "team1", "team_alias1"
)
prometheus_logger.litellm_llm_api_time_to_first_token_metric.labels().observe.assert_called_once_with(
0.5
)
# end_time - api_call_start_time
prometheus_logger.litellm_llm_api_latency_metric.labels.assert_called_once_with(
"gpt-3.5-turbo", "key1", "alias1", "team1", "team_alias1"
)
prometheus_logger.litellm_llm_api_latency_metric.labels().observe.assert_called_once_with(
1.5
)
# total latency for the request
prometheus_logger.litellm_request_total_latency_metric.labels.assert_called_once_with(
"gpt-3.5-turbo", "key1", "alias1", "team1", "team_alias1"
)
prometheus_logger.litellm_request_total_latency_metric.labels().observe.assert_called_once_with(
2.0
)
def test_increment_top_level_request_and_spend_metrics(prometheus_logger):
"""
Test the increment_top_level_request_and_spend_metrics method
- litellm_requests_metric is incremented by 1
- litellm_spend_metric is incremented by the response cost in the standard logging payload
"""
prometheus_logger.litellm_requests_metric = MagicMock()
prometheus_logger.litellm_spend_metric = MagicMock()
prometheus_logger._increment_top_level_request_and_spend_metrics(
end_user_id="user1",
user_api_key="key1",
user_api_key_alias="alias1",
model="gpt-3.5-turbo",
user_api_team="team1",
user_api_team_alias="team_alias1",
user_id="user1",
response_cost=0.1,
)
prometheus_logger.litellm_requests_metric.labels.assert_called_once_with(
"user1", "key1", "alias1", "gpt-3.5-turbo", "team1", "team_alias1", "user1"
)
prometheus_logger.litellm_requests_metric.labels().inc.assert_called_once()
prometheus_logger.litellm_spend_metric.labels.assert_called_once_with(
"user1", "key1", "alias1", "gpt-3.5-turbo", "team1", "team_alias1", "user1"
)
prometheus_logger.litellm_spend_metric.labels().inc.assert_called_once_with(0.1)
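
One MagicMock detail the assertions above rely on: calling a mocked metric's labels(...) returns the same labels.return_value child mock regardless of the arguments, so metric.labels().inc.assert_called_once_with(...) sees the increment recorded by the code under test. The label values themselves are checked separately with labels.assert_called_once_with(...), which is why that assertion runs before the extra labels() call made by the inc assertion. A small standalone illustration (the bare metric variable here is just a MagicMock, not a litellm object):

    from unittest.mock import MagicMock

    metric = MagicMock()

    # what the code under test does
    metric.labels("user1", "key1").inc(100)

    # labels(...) always returns labels.return_value, so the recorded inc call
    # is reachable through labels() with no arguments
    metric.labels.assert_called_once_with("user1", "key1")
    metric.labels().inc.assert_called_once_with(100)
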


@@ -17,6 +17,7 @@ import litellm
 import asyncio
 import logging
 from litellm._logging import verbose_logger
+from prometheus_client import REGISTRY, CollectorRegistry

 from litellm.integrations.lago import LagoLogger
 from litellm.integrations.openmeter import OpenMeterLogger
@@ -33,6 +34,12 @@ from litellm.integrations.argilla import ArgillaLogger
 from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
 from unittest.mock import patch

+# clear prometheus collectors / registry
+collectors = list(REGISTRY._collector_to_names.keys())
+for collector in collectors:
+    REGISTRY.unregister(collector)
+######################################
+
 callback_class_str_to_classType = {
     "lago": LagoLogger,
     "openmeter": OpenMeterLogger,
@@ -111,6 +118,11 @@ async def use_callback_in_llm_call(
     elif callback == "openmeter":
         # it's currently handled in jank way, TODO: fix openmete and then actually run it's test
         return
+    elif callback == "prometheus":
+        # pytest teardown - clear existing prometheus collectors
+        collectors = list(REGISTRY._collector_to_names.keys())
+        for collector in collectors:
+            REGISTRY.unregister(collector)

     # Mock the httpx call for Argilla dataset retrieval
     if callback == "argilla":
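
A closing note on the REGISTRY cleanup that appears in both test files above: prometheus_client keeps collectors in a process-global default registry and raises ValueError ("Duplicated timeseries in CollectorRegistry") when a metric with an already-registered name is created again, which is what the fixture guards against when PrometheusLogger is constructed more than once in the same pytest process. A minimal sketch of the failure and the cleanup (the demo_requests metric name is made up for illustration):

    from prometheus_client import REGISTRY, Counter

    Counter("demo_requests", "first registration")  # registers into the default REGISTRY
    try:
        Counter("demo_requests", "second registration")  # same name, same registry
    except ValueError as err:
        print(f"duplicate registration rejected: {err}")

    # the cleanup used in the tests above: drop every collector from the default registry
    for collector in list(REGISTRY._collector_to_names.keys()):
        REGISTRY.unregister(collector)
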