fix(utils.py): allow disabling end user cost tracking with new param

Allows proxy admin to disable cost tracking for end user - keeps prometheus metrics small
This commit is contained in:
Krrish Dholakia 2024-11-22 16:41:58 +05:30
parent 1c9a8c0b68
commit 5a698c678a
6 changed files with 40 additions and 9 deletions

View file

@ -280,6 +280,7 @@ default_max_internal_user_budget: Optional[float] = None
max_internal_user_budget: Optional[float] = None
internal_user_budget_duration: Optional[str] = None
max_end_user_budget: Optional[float] = None
disable_end_user_cost_tracking: Optional[bool] = None
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
#### RELIABILITY ####

View file

@ -18,6 +18,7 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.integrations.prometheus import *
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
class PrometheusLogger(CustomLogger):
@ -364,8 +365,7 @@ class PrometheusLogger(CustomLogger):
model = kwargs.get("model", "")
litellm_params = kwargs.get("litellm_params", {}) or {}
_metadata = litellm_params.get("metadata", {})
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
@ -664,13 +664,11 @@ class PrometheusLogger(CustomLogger):
# unpack kwargs
model = kwargs.get("model", "")
litellm_params = kwargs.get("litellm_params", {}) or {}
standard_logging_payload: StandardLoggingPayload = kwargs.get(
"standard_logging_object", {}
)
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
litellm_params = kwargs.get("litellm_params", {}) or {}
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]

View file

@ -13,4 +13,6 @@ model_list:
vertex_ai_location: "us-east5"
litellm_settings:
success_callback: ["langfuse"]
success_callback: ["langfuse"]
callbacks: ["prometheus"]
# disable_end_user_cost_tracking: true

View file

@ -268,6 +268,7 @@ from litellm.types.llms.anthropic import (
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import RouterGeneralSettings
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
try:
from litellm._version import version
@ -763,8 +764,7 @@ async def _PROXY_track_cost_callback(
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
user_id = metadata.get("user_api_key_user_id", None)
team_id = metadata.get("user_api_key_team_id", None)

View file

@ -6170,3 +6170,13 @@ class ProviderConfigManager:
return litellm.GroqChatConfig()
return OpenAIGPTConfig()
def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]:
"""
Used for enforcing `disable_end_user_cost_tracking` param.
"""
proxy_server_request = litellm_params.get("proxy_server_request") or {}
if litellm.disable_end_user_cost_tracking:
return None
return proxy_server_request.get("body", {}).get("user", None)

View file

@ -1012,3 +1012,23 @@ def test_models_by_provider():
for provider in providers:
assert provider in models_by_provider.keys()
@pytest.mark.parametrize(
"litellm_params, disable_end_user_cost_tracking, expected_end_user_id",
[
({}, False, None),
({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
],
)
def test_get_end_user_id_for_cost_tracking(
litellm_params, disable_end_user_cost_tracking, expected_end_user_id
):
from litellm.utils import get_end_user_id_for_cost_tracking
litellm.disable_end_user_cost_tracking = disable_end_user_cost_tracking
assert (
get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
== expected_end_user_id
)