Litellm dev 12 25 2024 p3 (#7421)

* refactor(prometheus.py): refactor to use a factory method for setting label values

allows end-user id tracking on Prometheus to be disabled and enforced end to end

* fix: fix linting error

* fix(prometheus.py): ensure label factory drops end-user value if disabled by user

* fix(prometheus.py): specify service_type in end user tracking get

* test: fix test

* test: add unit test for prometheus factory

* test: improve test (cover flag not set scenario)

* test(test_prometheus.py): e2e test verifying that 'end_user_id' does not show up in metrics when tracking is disabled

scrapes the `/metrics` endpoint and scans the text to check whether the id appears in the emitted metrics (a simplified sketch of this check follows the commit header below)

* fix(prometheus.py): stringify status code before logging it
Krish Dholakia 2024-12-25 18:54:24 -08:00 committed by GitHub
parent 760328b6ad
commit 21e8f212d7
8 changed files with 223 additions and 90 deletions
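
To illustrate the e2e check described above, here is a simplified, synchronous sketch (the real test in this PR is async and uses the proxy's own fixtures; the URL is the local proxy address used elsewhere in the test suite, and `requests` is used only for illustration):

    # Rough sketch of the /metrics scrape check, assuming a locally running proxy.
    # The actual e2e test added in this PR is async (aiohttp) and lives in the
    # prometheus e2e test file shown at the bottom of this diff.
    import requests

    END_USER_ID = "my-test-user-34"  # value taken from the e2e test in this diff

    def assert_end_user_not_in_metrics(base_url: str = "http://0.0.0.0:4000") -> None:
        # Scrape the Prometheus endpoint exposed by the proxy
        resp = requests.get(f"{base_url}/metrics", timeout=10)
        resp.raise_for_status()
        metrics_text = resp.text

        # With disable_end_user_cost_tracking_prometheus_only set, the raw end-user id
        # must not appear anywhere in the emitted metrics.
        assert END_USER_ID not in metrics_text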


@@ -3,7 +3,7 @@
 # On success, log events to Prometheus
 import sys
 from datetime import datetime, timedelta
-from typing import Optional
+from typing import List, Optional

 from litellm._logging import print_verbose, verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
@@ -52,48 +52,21 @@ class PrometheusLogger(CustomLogger):
         self.litellm_proxy_total_requests_metric = Counter(
             name="litellm_proxy_total_requests_metric",
             documentation="Total number of requests made to the proxy server - track number of client side requests",
-            labelnames=[
-                "end_user",
-                "hashed_api_key",
-                "api_key_alias",
-                REQUESTED_MODEL,
-                "team",
-                "team_alias",
-                "user",
-                STATUS_CODE,
-            ],
+            labelnames=PrometheusMetricLabels.litellm_proxy_total_requests_metric.value,
         )

         # request latency metrics
         self.litellm_request_total_latency_metric = Histogram(
             "litellm_request_total_latency_metric",
             "Total latency (seconds) for a request to LiteLLM",
-            labelnames=[
-                UserAPIKeyLabelNames.END_USER.value,
-                UserAPIKeyLabelNames.API_KEY_HASH.value,
-                UserAPIKeyLabelNames.API_KEY_ALIAS.value,
-                REQUESTED_MODEL,
-                UserAPIKeyLabelNames.TEAM.value,
-                UserAPIKeyLabelNames.TEAM_ALIAS.value,
-                UserAPIKeyLabelNames.USER.value,
-                UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value,
-            ],
+            labelnames=PrometheusMetricLabels.litellm_request_total_latency_metric.value,
             buckets=LATENCY_BUCKETS,
         )

         self.litellm_llm_api_latency_metric = Histogram(
             "litellm_llm_api_latency_metric",
             "Total latency (seconds) for a models LLM API call",
-            labelnames=[
-                UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value,
-                UserAPIKeyLabelNames.API_KEY_HASH.value,
-                UserAPIKeyLabelNames.API_KEY_ALIAS.value,
-                UserAPIKeyLabelNames.TEAM.value,
-                UserAPIKeyLabelNames.TEAM_ALIAS.value,
-                UserAPIKeyLabelNames.REQUESTED_MODEL.value,
-                UserAPIKeyLabelNames.END_USER.value,
-                UserAPIKeyLabelNames.USER.value,
-            ],
+            labelnames=PrometheusMetricLabels.litellm_llm_api_latency_metric.value,
             buckets=LATENCY_BUCKETS,
         )
@@ -419,6 +392,17 @@ class PrometheusLogger(CustomLogger):
             f"inside track_prometheus_metrics, model {model}, response_cost {response_cost}, tokens_used {tokens_used}, end_user_id {end_user_id}, user_api_key {user_api_key}"
         )

+        enum_values = UserAPIKeyLabelValues(
+            end_user=end_user_id,
+            hashed_api_key=user_api_key,
+            api_key_alias=user_api_key_alias,
+            requested_model=model,
+            team=user_api_team,
+            team_alias=user_api_team_alias,
+            user=user_id,
+            status_code="200",
+        )
+
         if (
             user_api_key is not None
             and isinstance(user_api_key, str)
@@ -494,16 +478,11 @@ class PrometheusLogger(CustomLogger):
         if (
             standard_logging_payload["stream"] is True
         ):  # log successful streaming requests from logging event hook.
-            self.litellm_proxy_total_requests_metric.labels(
-                end_user=end_user_id,
-                hashed_api_key=user_api_key,
-                api_key_alias=user_api_key_alias,
-                requested_model=model,
-                team=user_api_team,
-                team_alias=user_api_team_alias,
-                user=user_id,
-                status_code="200",
-            ).inc()
+            _labels = prometheus_label_factory(
+                supported_enum_labels=PrometheusMetricLabels.litellm_proxy_total_requests_metric.value,
+                enum_values=enum_values,
+            )
+            self.litellm_proxy_total_requests_metric.labels(**_labels).inc()

     def _increment_token_metrics(
         self,
@@ -683,6 +662,24 @@ class PrometheusLogger(CustomLogger):
         completion_start_time = kwargs.get("completion_start_time", None)

+        enum_values = UserAPIKeyLabelValues(
+            end_user=standard_logging_payload["metadata"]["user_api_key_end_user_id"],
+            user=standard_logging_payload["metadata"]["user_api_key_user_id"],
+            hashed_api_key=user_api_key,
+            api_key_alias=user_api_key_alias,
+            team=user_api_team,
+            team_alias=user_api_team_alias,
+            requested_model=standard_logging_payload["model_group"],
+            model=model,
+            litellm_model_name=standard_logging_payload["model_group"],
+            tags=standard_logging_payload["request_tags"],
+            model_id=standard_logging_payload["model_id"],
+            api_base=standard_logging_payload["api_base"],
+            api_provider=standard_logging_payload["custom_llm_provider"],
+            exception_status=None,
+            exception_class=None,
+        )
+
         if (
             completion_start_time is not None
             and isinstance(completion_start_time, datetime)
@@ -708,46 +705,25 @@ class PrometheusLogger(CustomLogger):
         ):
             api_call_total_time: timedelta = end_time - api_call_start_time
             api_call_total_time_seconds = api_call_total_time.total_seconds()
-            self.litellm_llm_api_latency_metric.labels(
-                **{
-                    UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value: model,
-                    UserAPIKeyLabelNames.API_KEY_HASH.value: user_api_key,
-                    UserAPIKeyLabelNames.API_KEY_ALIAS.value: user_api_key_alias,
-                    UserAPIKeyLabelNames.TEAM.value: user_api_team,
-                    UserAPIKeyLabelNames.TEAM_ALIAS.value: user_api_team_alias,
-                    UserAPIKeyLabelNames.USER.value: standard_logging_payload[
-                        "metadata"
-                    ]["user_api_key_user_id"],
-                    UserAPIKeyLabelNames.END_USER.value: standard_logging_payload[
-                        "metadata"
-                    ]["user_api_key_end_user_id"],
-                    UserAPIKeyLabelNames.REQUESTED_MODEL.value: standard_logging_payload[
-                        "model_group"
-                    ],
-                }
-            ).observe(api_call_total_time_seconds)
+            _labels = prometheus_label_factory(
+                supported_enum_labels=PrometheusMetricLabels.litellm_llm_api_latency_metric.value,
+                enum_values=enum_values,
+            )
+            self.litellm_llm_api_latency_metric.labels(**_labels).observe(
+                api_call_total_time_seconds
+            )

         # total request latency
         if start_time is not None and isinstance(start_time, datetime):
             total_time: timedelta = end_time - start_time
             total_time_seconds = total_time.total_seconds()
-            self.litellm_request_total_latency_metric.labels(
-                **{
-                    UserAPIKeyLabelNames.END_USER.value: standard_logging_payload[
-                        "metadata"
-                    ]["user_api_key_end_user_id"],
-                    UserAPIKeyLabelNames.API_KEY_HASH.value: user_api_key,
-                    UserAPIKeyLabelNames.API_KEY_ALIAS.value: user_api_key_alias,
-                    REQUESTED_MODEL: standard_logging_payload["model_group"],
-                    UserAPIKeyLabelNames.TEAM.value: user_api_team,
-                    UserAPIKeyLabelNames.TEAM_ALIAS.value: user_api_team_alias,
-                    UserAPIKeyLabelNames.USER.value: standard_logging_payload[
-                        "metadata"
-                    ]["user_api_key_user_id"],
-                    UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value: model,
-                }
-            ).observe(total_time_seconds)
+            _labels = prometheus_label_factory(
+                supported_enum_labels=PrometheusMetricLabels.litellm_request_total_latency_metric.value,
+                enum_values=enum_values,
+            )
+            self.litellm_request_total_latency_metric.labels(**_labels).observe(
+                total_time_seconds
+            )

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         from litellm.types.utils import StandardLoggingPayload
@@ -813,6 +789,18 @@ class PrometheusLogger(CustomLogger):
            ] + EXCEPTION_LABELS,
        """
        try:
+            enum_values = UserAPIKeyLabelValues(
+                end_user=user_api_key_dict.end_user_id,
+                user=user_api_key_dict.user_id,
+                hashed_api_key=user_api_key_dict.api_key,
+                api_key_alias=user_api_key_dict.key_alias,
+                team=user_api_key_dict.team_id,
+                team_alias=user_api_key_dict.team_alias,
+                requested_model=request_data.get("model", ""),
+                status_code=str(getattr(original_exception, "status_code", None)),
+                exception_class=str(original_exception.__class__.__name__),
+            )
+
            self.litellm_proxy_failed_requests_metric.labels(
                end_user=user_api_key_dict.end_user_id,
                hashed_api_key=user_api_key_dict.api_key,
@@ -825,16 +813,11 @@ class PrometheusLogger(CustomLogger):
                exception_class=str(original_exception.__class__.__name__),
            ).inc()

-            self.litellm_proxy_total_requests_metric.labels(
-                end_user=user_api_key_dict.end_user_id,
-                hashed_api_key=user_api_key_dict.api_key,
-                api_key_alias=user_api_key_dict.key_alias,
-                requested_model=request_data.get("model", ""),
-                team=user_api_key_dict.team_id,
-                team_alias=user_api_key_dict.team_alias,
-                user=user_api_key_dict.user_id,
-                status_code=str(getattr(original_exception, "status_code", None)),
-            ).inc()
+            _labels = prometheus_label_factory(
+                supported_enum_labels=PrometheusMetricLabels.litellm_proxy_total_requests_metric.value,
+                enum_values=enum_values,
+            )
+            self.litellm_proxy_total_requests_metric.labels(**_labels).inc()
            pass
        except Exception as e:
            verbose_logger.exception(
@@ -849,7 +832,7 @@
        Proxy level tracking - triggered when the proxy responds with a success response to the client
        """
        try:
-            self.litellm_proxy_total_requests_metric.labels(
+            enum_values = UserAPIKeyLabelValues(
                end_user=user_api_key_dict.end_user_id,
                hashed_api_key=user_api_key_dict.api_key,
                api_key_alias=user_api_key_dict.key_alias,
@@ -858,7 +841,12 @@
                team_alias=user_api_key_dict.team_alias,
                user=user_api_key_dict.user_id,
                status_code="200",
-            ).inc()
+            )
+            _labels = prometheus_label_factory(
+                supported_enum_labels=PrometheusMetricLabels.litellm_proxy_total_requests_metric.value,
+                enum_values=enum_values,
+            )
+            self.litellm_proxy_total_requests_metric.labels(**_labels).inc()
        except Exception as e:
            verbose_logger.exception(
                "prometheus Layer Error(): Exception occured - {}".format(str(e))
@@ -1278,3 +1266,30 @@
            return max_budget

        return max_budget - spend
+
+
+def prometheus_label_factory(
+    supported_enum_labels: List[str], enum_values: UserAPIKeyLabelValues
+) -> dict:
+    """
+    Returns a dictionary of label + values for prometheus.
+
+    Ensures end_user param is not sent to prometheus if it is not supported.
+    """
+    # Extract dictionary from Pydantic object
+    enum_dict = enum_values.model_dump()
+
+    # Filter supported labels
+    filtered_labels = {
+        label: value
+        for label, value in enum_dict.items()
+        if label in supported_enum_labels
+    }
+
+    if UserAPIKeyLabelNames.END_USER.value in filtered_labels:
+        filtered_labels["end_user"] = get_end_user_id_for_cost_tracking(
+            litellm_params={"user_api_key_end_user_id": enum_values.end_user},
+            service_type="prometheus",
+        )
+
+    return filtered_labels
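
For context, a minimal usage sketch of the factory added above. This is a standalone approximation, not code from the PR: real callers pass the shared `enum_values` object built earlier in the logger, and the import paths are inferred from the unit test later in this diff.

    # Hedged sketch: how a metric site consumes prometheus_label_factory.
    # Import paths inferred from the unit test in this diff; the values are dummies.
    from litellm.integrations.prometheus import prometheus_label_factory
    from litellm.types.integrations.prometheus import (
        PrometheusMetricLabels,
        UserAPIKeyLabelValues,
    )

    enum_values = UserAPIKeyLabelValues(
        end_user="user-123",
        hashed_api_key="abc123",
        api_key_alias="prod-key",
        requested_model="gpt-4o",
        team="team-a",
        team_alias="team-a-alias",
        user="internal-user-1",
        status_code="200",
    )

    # Only the labels declared for this metric survive the filter; end_user is
    # additionally blanked when disable_end_user_cost_tracking_prometheus_only is set.
    labels = prometheus_label_factory(
        supported_enum_labels=PrometheusMetricLabels.litellm_proxy_total_requests_metric.value,
        enum_values=enum_values,
    )
    # self.litellm_proxy_total_requests_metric.labels(**labels).inc()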


@@ -12,7 +12,7 @@ import time
 import traceback
 import uuid
 from datetime import datetime as dt_object
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast

 from pydantic import BaseModel
@@ -2983,6 +2983,7 @@ def get_standard_logging_object_payload(
        cache_hit=cache_hit,
        stream=stream,
        status=status,
+        custom_llm_provider=cast(Optional[str], kwargs.get("custom_llm_provider")),
        saved_cache_cost=saved_cache_cost,
        startTime=start_time_float,
        endTime=end_time_float,


@@ -16,4 +16,5 @@ model_list:
      mode: audio_transcription

 litellm_settings:
   callbacks: ["prometheus"]
+  disable_end_user_cost_tracking_prometheus_only: true
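
The YAML flag added above can also be flipped programmatically; the attribute path below is taken from the unit test later in this diff, and the snippet is only an illustrative sketch, not code from the PR.

    # Sketch: runtime equivalent of the config line above.
    import litellm

    # When True, prometheus_label_factory blanks the end_user label before emission.
    litellm.disable_end_user_cost_tracking_prometheus_only = True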


@@ -37,6 +37,7 @@ model_list:

 litellm_settings:
   cache: true
   callbacks: ["otel", "prometheus"]
+  disable_end_user_cost_tracking_prometheus_only: True

 guardrails:
   - guardrail_name: "aporia-pre-guard"


@@ -1,4 +1,7 @@
 from enum import Enum
+from typing import List, Optional, Union
+
+from pydantic import BaseModel, Field

 REQUESTED_MODEL = "requested_model"
 EXCEPTION_STATUS = "exception_status"
@@ -61,3 +64,82 @@ class UserAPIKeyLabelNames(Enum):
     API_PROVIDER = "api_provider"
     EXCEPTION_STATUS = EXCEPTION_STATUS
     EXCEPTION_CLASS = EXCEPTION_CLASS
+    STATUS_CODE = "status_code"
+
+
+class PrometheusMetricLabels(Enum):
+    litellm_llm_api_latency_metric = [
+        UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value,
+        UserAPIKeyLabelNames.API_KEY_HASH.value,
+        UserAPIKeyLabelNames.API_KEY_ALIAS.value,
+        UserAPIKeyLabelNames.TEAM.value,
+        UserAPIKeyLabelNames.TEAM_ALIAS.value,
+        UserAPIKeyLabelNames.REQUESTED_MODEL.value,
+        UserAPIKeyLabelNames.END_USER.value,
+        UserAPIKeyLabelNames.USER.value,
+    ]
+
+    litellm_request_total_latency_metric = [
+        UserAPIKeyLabelNames.END_USER.value,
+        UserAPIKeyLabelNames.API_KEY_HASH.value,
+        UserAPIKeyLabelNames.API_KEY_ALIAS.value,
+        UserAPIKeyLabelNames.REQUESTED_MODEL.value,
+        UserAPIKeyLabelNames.TEAM.value,
+        UserAPIKeyLabelNames.TEAM_ALIAS.value,
+        UserAPIKeyLabelNames.USER.value,
+        UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value,
+    ]
+
+    litellm_proxy_total_requests_metric = [
+        UserAPIKeyLabelNames.END_USER.value,
+        UserAPIKeyLabelNames.API_KEY_HASH.value,
+        UserAPIKeyLabelNames.API_KEY_ALIAS.value,
+        UserAPIKeyLabelNames.REQUESTED_MODEL.value,
+        UserAPIKeyLabelNames.TEAM.value,
+        UserAPIKeyLabelNames.TEAM_ALIAS.value,
+        UserAPIKeyLabelNames.USER.value,
+        UserAPIKeyLabelNames.STATUS_CODE.value,
+    ]
+
+
+from typing import List, Optional
+
+from pydantic import BaseModel, Field
+
+
+class UserAPIKeyLabelValues(BaseModel):
+    end_user: Optional[str] = None
+    user: Optional[str] = None
+    hashed_api_key: Optional[str] = None
+    api_key_alias: Optional[str] = None
+    team: Optional[str] = None
+    team_alias: Optional[str] = None
+    requested_model: Optional[str] = None
+    model: Optional[str] = None
+    litellm_model_name: Optional[str] = None
+    tags: List[str] = []
+    model_id: Optional[str] = None
+    api_base: Optional[str] = None
+    api_provider: Optional[str] = None
+    exception_status: Optional[str] = None
+    exception_class: Optional[str] = None
+    status_code: Optional[str] = None
+
+    class Config:
+        fields = {
+            "end_user": {"alias": UserAPIKeyLabelNames.END_USER},
+            "user": {"alias": UserAPIKeyLabelNames.USER},
+            "hashed_api_key": {"alias": UserAPIKeyLabelNames.API_KEY_HASH},
+            "api_key_alias": {"alias": UserAPIKeyLabelNames.API_KEY_ALIAS},
+            "team": {"alias": UserAPIKeyLabelNames.TEAM},
+            "team_alias": {"alias": UserAPIKeyLabelNames.TEAM_ALIAS},
+            "requested_model": {"alias": UserAPIKeyLabelNames.REQUESTED_MODEL},
+            "model": {"alias": UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME},
+            "litellm_model_name": {"alias": UserAPIKeyLabelNames.v2_LITELLM_MODEL_NAME},
+            "model_id": {"alias": UserAPIKeyLabelNames.MODEL_ID},
+            "api_base": {"alias": UserAPIKeyLabelNames.API_BASE},
+            "api_provider": {"alias": UserAPIKeyLabelNames.API_PROVIDER},
+            "exception_status": {"alias": UserAPIKeyLabelNames.EXCEPTION_STATUS},
+            "exception_class": {"alias": UserAPIKeyLabelNames.EXCEPTION_CLASS},
+            "status_code": {"alias": UserAPIKeyLabelNames.STATUS_CODE},
+        }
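
The field names on this model match the label strings collected in the `PrometheusMetricLabels` lists above, which is what lets the factory filter by plain string keys. A small, standalone sketch of that interaction (illustrative only; assumes the module imports as in the unit test below, and Pydantic v2 may warn about the legacy `Config` block):

    # Sketch: why the factory can filter model_dump() output by label name.
    values = UserAPIKeyLabelValues(end_user="u-1", team="team-a", status_code="500")
    dumped = values.model_dump()
    assert "end_user" in dumped and dumped["team"] == "team-a"

    supported = PrometheusMetricLabels.litellm_proxy_total_requests_metric.value
    filtered = {k: v for k, v in dumped.items() if k in supported}
    # filtered now contains only the labels declared on litellm_proxy_total_requests_metric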


@@ -1513,6 +1513,7 @@
        StandardLoggingModelCostFailureDebugInformation
    ]
    status: StandardLoggingPayloadStatus
+    custom_llm_provider: Optional[str]
    total_tokens: int
    prompt_tokens: int
    completion_tokens: int


@@ -62,6 +62,7 @@ def create_standard_logging_payload() -> StandardLoggingPayload:
        model="gpt-3.5-turbo",
        model_id="model-123",
        model_group="openai-gpt",
+        custom_llm_provider="openai",
        api_base="https://api.openai.com",
        metadata=StandardLoggingMetadata(
            user_api_key_hash="test_hash",
@@ -793,3 +794,29 @@ def test_increment_deployment_cooled_down(prometheus_logger):
        "gpt-3.5-turbo", "model-123", "https://api.openai.com", "openai", "429"
    )
    prometheus_logger.litellm_deployment_cooled_down.labels().inc.assert_called_once()
+
+
+@pytest.mark.parametrize("disable_end_user_tracking", [True, False])
+def test_prometheus_factory(monkeypatch, disable_end_user_tracking):
+    from litellm.integrations.prometheus import prometheus_label_factory
+    from litellm.types.integrations.prometheus import UserAPIKeyLabelValues
+
+    monkeypatch.setattr(
+        "litellm.disable_end_user_cost_tracking_prometheus_only",
+        disable_end_user_tracking,
+    )
+
+    enum_values = UserAPIKeyLabelValues(
+        end_user="test_end_user",
+        api_key_hash="test_hash",
+        api_key_alias="test_alias",
+    )
+    supported_labels = ["end_user", "api_key_hash", "api_key_alias"]
+    returned_dict = prometheus_label_factory(
+        supported_enum_labels=supported_labels, enum_values=enum_values
+    )
+
+    if disable_end_user_tracking:
+        assert returned_dict["end_user"] == None
+    else:
+        assert returned_dict["end_user"] == "test_end_user"


@@ -13,6 +13,8 @@ sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

+END_USER_ID = "my-test-user-34"
+

async def make_bad_chat_completion_request(session, key):
    url = "http://0.0.0.0:4000/chat/completions"
@@ -41,6 +43,7 @@ async def make_good_chat_completion_request(session, key):
        "model": "fake-openai-endpoint",
        "messages": [{"role": "user", "content": f"Hello {uuid.uuid4()}"}],
        "tags": ["teamB"],
+        "user": END_USER_ID,  # test if disable end user tracking for prometheus works
    }
    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
@@ -143,6 +146,8 @@ async def test_proxy_success_metrics():
    print("/metrics", metrics)

+    assert END_USER_ID not in metrics
+
    # Check if the success metric is present and correct
    assert (
        'litellm_request_total_latency_metric_bucket{api_key_alias="None",end_user="None",hashed_api_key="88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",le="0.005",model="fake",requested_model="fake-openai-endpoint",team="None",team_alias="None",user="default_user_id"}'