Litellm dev 12 25 2024 p3 (#7421)

* refactor(prometheus.py): refactor to use a factory method for setting label values

allows enforcing end-user id disabling on Prometheus, end to end

* fix: fix linting error

* fix(prometheus.py): ensure label factory drops end-user value if disabled by user

* fix(prometheus.py): specify service_type in end user tracking get

* test: fix test

* test: add unit test for prometheus factory

* test: improve test (cover flag not set scenario)

* test(test_prometheus.py): e2e test checking that 'end_user_id' does not show up in emitted metrics when disabled

scrapes the `/metrics` endpoint and scans the text to check whether the id appears in the emitted metrics (see the sketch below)

* fix(prometheus.py): stringify status code before logging it
Krish Dholakia, 2024-12-25 18:54:24 -08:00 (committed by GitHub)
commit 21e8f212d7, parent 760328b6ad
8 changed files with 223 additions and 90 deletions
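To make the e2e check concrete, here is a minimal sketch of such a scrape-and-scan test. It is not the PR's actual test; the proxy URL, API key, and test name are illustrative, and it assumes a running LiteLLM proxy plus `httpx` and `pytest-asyncio`:

```python
import httpx
import pytest

PROXY_BASE_URL = "http://localhost:4000"  # assumed local proxy address


@pytest.mark.asyncio
async def test_end_user_id_absent_from_metrics():
    """With end-user tracking disabled, the raw id should never appear in /metrics."""
    end_user_id = "my-unique-end-user-id"

    async with httpx.AsyncClient() as client:
        # Send a completion request tagged with the end-user id.
        await client.post(
            f"{PROXY_BASE_URL}/chat/completions",
            json={
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": "hi"}],
                "user": end_user_id,
            },
            headers={"Authorization": "Bearer sk-1234"},  # assumed test key
        )

        # Scrape the Prometheus endpoint and scan the raw exposition text.
        metrics_text = (await client.get(f"{PROXY_BASE_URL}/metrics")).text

    assert end_user_id not in metrics_text
```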

litellm/integrations/prometheus.py

@@ -3,7 +3,7 @@
 # On success, log events to Prometheus
 import sys
 from datetime import datetime, timedelta
-from typing import Optional
+from typing import List, Optional
 from litellm._logging import print_verbose, verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
@@ -52,48 +52,21 @@ class PrometheusLogger(CustomLogger):
         self.litellm_proxy_total_requests_metric = Counter(
             name="litellm_proxy_total_requests_metric",
             documentation="Total number of requests made to the proxy server - track number of client side requests",
-            labelnames=[
-                "end_user",
-                "hashed_api_key",
-                "api_key_alias",
-                REQUESTED_MODEL,
-                "team",
-                "team_alias",
-                "user",
-                STATUS_CODE,
-            ],
+            labelnames=PrometheusMetricLabels.litellm_proxy_total_requests_metric.value,
         )

         # request latency metrics
         self.litellm_request_total_latency_metric = Histogram(
             "litellm_request_total_latency_metric",
             "Total latency (seconds) for a request to LiteLLM",
-            labelnames=[
-                UserAPIKeyLabelNames.END_USER.value,
-                UserAPIKeyLabelNames.API_KEY_HASH.value,
-                UserAPIKeyLabelNames.API_KEY_ALIAS.value,
-                REQUESTED_MODEL,
-                UserAPIKeyLabelNames.TEAM.value,
-                UserAPIKeyLabelNames.TEAM_ALIAS.value,
-                UserAPIKeyLabelNames.USER.value,
-                UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value,
-            ],
+            labelnames=PrometheusMetricLabels.litellm_request_total_latency_metric.value,
             buckets=LATENCY_BUCKETS,
         )

         self.litellm_llm_api_latency_metric = Histogram(
             "litellm_llm_api_latency_metric",
             "Total latency (seconds) for a models LLM API call",
-            labelnames=[
-                UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value,
-                UserAPIKeyLabelNames.API_KEY_HASH.value,
-                UserAPIKeyLabelNames.API_KEY_ALIAS.value,
-                UserAPIKeyLabelNames.TEAM.value,
-                UserAPIKeyLabelNames.TEAM_ALIAS.value,
-                UserAPIKeyLabelNames.REQUESTED_MODEL.value,
-                UserAPIKeyLabelNames.END_USER.value,
-                UserAPIKeyLabelNames.USER.value,
-            ],
+            labelnames=PrometheusMetricLabels.litellm_llm_api_latency_metric.value,
             buckets=LATENCY_BUCKETS,
         )
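`PrometheusMetricLabels`, consumed by the new `labelnames=` lines, is defined outside this excerpt. Judging from how `.value` is passed straight to the prometheus client, each member presumably holds the list of label names its metric is registered with; a rough sketch (the label lists mirror the literals removed above, but the enum shape itself is an assumption):

```python
from enum import Enum


class PrometheusMetricLabels(Enum):
    # Assumed shape: each member's value is the label-name list the
    # corresponding metric is registered with, so metric definition and
    # label emission can no longer drift apart.
    litellm_proxy_total_requests_metric = [
        "end_user",
        "hashed_api_key",
        "api_key_alias",
        "requested_model",
        "team",
        "team_alias",
        "user",
        "status_code",
    ]
    litellm_request_total_latency_metric = [
        "end_user",
        "hashed_api_key",
        "api_key_alias",
        "requested_model",
        "team",
        "team_alias",
        "user",
        "litellm_model_name",
    ]
```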
@@ -419,6 +392,17 @@ class PrometheusLogger(CustomLogger):
             f"inside track_prometheus_metrics, model {model}, response_cost {response_cost}, tokens_used {tokens_used}, end_user_id {end_user_id}, user_api_key {user_api_key}"
         )

+        enum_values = UserAPIKeyLabelValues(
+            end_user=end_user_id,
+            hashed_api_key=user_api_key,
+            api_key_alias=user_api_key_alias,
+            requested_model=model,
+            team=user_api_team,
+            team_alias=user_api_team_alias,
+            user=user_id,
+            status_code="200",
+        )
+
         if (
             user_api_key is not None
             and isinstance(user_api_key, str)
@@ -494,16 +478,11 @@
         if (
             standard_logging_payload["stream"] is True
         ):  # log successful streaming requests from logging event hook.
-            self.litellm_proxy_total_requests_metric.labels(
-                end_user=end_user_id,
-                hashed_api_key=user_api_key,
-                api_key_alias=user_api_key_alias,
-                requested_model=model,
-                team=user_api_team,
-                team_alias=user_api_team_alias,
-                user=user_id,
-                status_code="200",
-            ).inc()
+            _labels = prometheus_label_factory(
+                supported_enum_labels=PrometheusMetricLabels.litellm_proxy_total_requests_metric.value,
+                enum_values=enum_values,
+            )
+            self.litellm_proxy_total_requests_metric.labels(**_labels).inc()

     def _increment_token_metrics(
         self,
@@ -683,6 +662,24 @@ class PrometheusLogger(CustomLogger):
         completion_start_time = kwargs.get("completion_start_time", None)

+        enum_values = UserAPIKeyLabelValues(
+            end_user=standard_logging_payload["metadata"]["user_api_key_end_user_id"],
+            user=standard_logging_payload["metadata"]["user_api_key_user_id"],
+            hashed_api_key=user_api_key,
+            api_key_alias=user_api_key_alias,
+            team=user_api_team,
+            team_alias=user_api_team_alias,
+            requested_model=standard_logging_payload["model_group"],
+            model=model,
+            litellm_model_name=standard_logging_payload["model_group"],
+            tags=standard_logging_payload["request_tags"],
+            model_id=standard_logging_payload["model_id"],
+            api_base=standard_logging_payload["api_base"],
+            api_provider=standard_logging_payload["custom_llm_provider"],
+            exception_status=None,
+            exception_class=None,
+        )
+
         if (
             completion_start_time is not None
             and isinstance(completion_start_time, datetime)
@@ -708,46 +705,25 @@
         ):
             api_call_total_time: timedelta = end_time - api_call_start_time
             api_call_total_time_seconds = api_call_total_time.total_seconds()
-            self.litellm_llm_api_latency_metric.labels(
-                **{
-                    UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value: model,
-                    UserAPIKeyLabelNames.API_KEY_HASH.value: user_api_key,
-                    UserAPIKeyLabelNames.API_KEY_ALIAS.value: user_api_key_alias,
-                    UserAPIKeyLabelNames.TEAM.value: user_api_team,
-                    UserAPIKeyLabelNames.TEAM_ALIAS.value: user_api_team_alias,
-                    UserAPIKeyLabelNames.USER.value: standard_logging_payload[
-                        "metadata"
-                    ]["user_api_key_user_id"],
-                    UserAPIKeyLabelNames.END_USER.value: standard_logging_payload[
-                        "metadata"
-                    ]["user_api_key_end_user_id"],
-                    UserAPIKeyLabelNames.REQUESTED_MODEL.value: standard_logging_payload[
-                        "model_group"
-                    ],
-                }
-            ).observe(api_call_total_time_seconds)
+            _labels = prometheus_label_factory(
+                supported_enum_labels=PrometheusMetricLabels.litellm_llm_api_latency_metric.value,
+                enum_values=enum_values,
+            )
+            self.litellm_llm_api_latency_metric.labels(**_labels).observe(
+                api_call_total_time_seconds
+            )

         # total request latency
         if start_time is not None and isinstance(start_time, datetime):
             total_time: timedelta = end_time - start_time
             total_time_seconds = total_time.total_seconds()
-            self.litellm_request_total_latency_metric.labels(
-                **{
-                    UserAPIKeyLabelNames.END_USER.value: standard_logging_payload[
-                        "metadata"
-                    ]["user_api_key_end_user_id"],
-                    UserAPIKeyLabelNames.API_KEY_HASH.value: user_api_key,
-                    UserAPIKeyLabelNames.API_KEY_ALIAS.value: user_api_key_alias,
-                    REQUESTED_MODEL: standard_logging_payload["model_group"],
-                    UserAPIKeyLabelNames.TEAM.value: user_api_team,
-                    UserAPIKeyLabelNames.TEAM_ALIAS.value: user_api_team_alias,
-                    UserAPIKeyLabelNames.USER.value: standard_logging_payload[
-                        "metadata"
-                    ]["user_api_key_user_id"],
-                    UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value: model,
-                }
-            ).observe(total_time_seconds)
+            _labels = prometheus_label_factory(
+                supported_enum_labels=PrometheusMetricLabels.litellm_request_total_latency_metric.value,
+                enum_values=enum_values,
+            )
+            self.litellm_request_total_latency_metric.labels(**_labels).observe(
+                total_time_seconds
+            )

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         from litellm.types.utils import StandardLoggingPayload
@@ -813,6 +789,18 @@ class PrometheusLogger(CustomLogger):
             ] + EXCEPTION_LABELS,
         """
         try:
+            enum_values = UserAPIKeyLabelValues(
+                end_user=user_api_key_dict.end_user_id,
+                user=user_api_key_dict.user_id,
+                hashed_api_key=user_api_key_dict.api_key,
+                api_key_alias=user_api_key_dict.key_alias,
+                team=user_api_key_dict.team_id,
+                team_alias=user_api_key_dict.team_alias,
+                requested_model=request_data.get("model", ""),
+                status_code=str(getattr(original_exception, "status_code", None)),
+                exception_class=str(original_exception.__class__.__name__),
+            )
+
             self.litellm_proxy_failed_requests_metric.labels(
                 end_user=user_api_key_dict.end_user_id,
                 hashed_api_key=user_api_key_dict.api_key,
@@ -825,16 +813,11 @@
                 exception_class=str(original_exception.__class__.__name__),
             ).inc()

-            self.litellm_proxy_total_requests_metric.labels(
-                end_user=user_api_key_dict.end_user_id,
-                hashed_api_key=user_api_key_dict.api_key,
-                api_key_alias=user_api_key_dict.key_alias,
-                requested_model=request_data.get("model", ""),
-                team=user_api_key_dict.team_id,
-                team_alias=user_api_key_dict.team_alias,
-                user=user_api_key_dict.user_id,
-                status_code=str(getattr(original_exception, "status_code", None)),
-            ).inc()
+            _labels = prometheus_label_factory(
+                supported_enum_labels=PrometheusMetricLabels.litellm_proxy_total_requests_metric.value,
+                enum_values=enum_values,
+            )
+            self.litellm_proxy_total_requests_metric.labels(**_labels).inc()
             pass
         except Exception as e:
             verbose_logger.exception(
@@ -849,7 +832,7 @@
         Proxy level tracking - triggered when the proxy responds with a success response to the client
         """
         try:
-            self.litellm_proxy_total_requests_metric.labels(
+            enum_values = UserAPIKeyLabelValues(
                 end_user=user_api_key_dict.end_user_id,
                 hashed_api_key=user_api_key_dict.api_key,
                 api_key_alias=user_api_key_dict.key_alias,
@@ -858,7 +841,12 @@
                 team_alias=user_api_key_dict.team_alias,
                 user=user_api_key_dict.user_id,
                 status_code="200",
-            ).inc()
+            )
+            _labels = prometheus_label_factory(
+                supported_enum_labels=PrometheusMetricLabels.litellm_proxy_total_requests_metric.value,
+                enum_values=enum_values,
+            )
+            self.litellm_proxy_total_requests_metric.labels(**_labels).inc()
         except Exception as e:
             verbose_logger.exception(
                 "prometheus Layer Error(): Exception occured - {}".format(str(e))
@@ -1278,3 +1266,30 @@
             return max_budget

         return max_budget - spend
+
+
+def prometheus_label_factory(
+    supported_enum_labels: List[str], enum_values: UserAPIKeyLabelValues
+) -> dict:
+    """
+    Returns a dictionary of label + values for prometheus.
+    Ensures end_user param is not sent to prometheus if it is not supported.
+    """
+
+    # Extract dictionary from Pydantic object
+    enum_dict = enum_values.model_dump()
+
+    # Filter supported labels
+    filtered_labels = {
+        label: value
+        for label, value in enum_dict.items()
+        if label in supported_enum_labels
+    }
+
+    if UserAPIKeyLabelNames.END_USER.value in filtered_labels:
+        filtered_labels["end_user"] = get_end_user_id_for_cost_tracking(
+            litellm_params={"user_api_key_end_user_id": enum_values.end_user},
+            service_type="prometheus",
+        )
+
+    return filtered_labels
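To illustrate what the factory buys, here is a sketch of the two interesting cases. This is not code from the PR; the label values are made up and the import path for `UserAPIKeyLabelValues` is assumed:

```python
from litellm.types.integrations.prometheus import UserAPIKeyLabelValues  # assumed path

enum_values = UserAPIKeyLabelValues(
    end_user="customer-123",
    hashed_api_key="hashed-key-123",
    api_key_alias="prod-key",
    requested_model="gpt-4o",
    team="eng",
    team_alias="engineering",
    user="user-1",
    status_code="200",
)

# Case 1: the metric's label set omits "end_user" - the id is filtered
# out before it can ever reach Prometheus.
labels = prometheus_label_factory(
    supported_enum_labels=["hashed_api_key", "requested_model", "status_code"],
    enum_values=enum_values,
)
assert "end_user" not in labels

# Case 2: the metric does support "end_user" - the value is still routed
# through get_end_user_id_for_cost_tracking(service_type="prometheus"),
# which returns None when end-user tracking is disabled, so the raw id
# is dropped there instead.
labels = prometheus_label_factory(
    supported_enum_labels=["end_user", "hashed_api_key", "status_code"],
    enum_values=enum_values,
)
```

Routing every metric through one factory means the end-user privacy setting is enforced in a single place, rather than at each of the call sites the diff above removes.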