diff --git a/litellm/__init__.py b/litellm/__init__.py index c978b24ee..e6dc61dc7 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -280,6 +280,7 @@ default_max_internal_user_budget: Optional[float] = None max_internal_user_budget: Optional[float] = None internal_user_budget_duration: Optional[str] = None max_end_user_budget: Optional[float] = None +disable_end_user_cost_tracking: Optional[bool] = None #### REQUEST PRIORITIZATION #### priority_reservation: Optional[Dict[str, float]] = None #### RELIABILITY #### diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index bb28719a3..1460a1d7f 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -18,6 +18,7 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth from litellm.types.integrations.prometheus import * from litellm.types.utils import StandardLoggingPayload +from litellm.utils import get_end_user_id_for_cost_tracking class PrometheusLogger(CustomLogger): @@ -364,8 +365,7 @@ class PrometheusLogger(CustomLogger): model = kwargs.get("model", "") litellm_params = kwargs.get("litellm_params", {}) or {} _metadata = litellm_params.get("metadata", {}) - proxy_server_request = litellm_params.get("proxy_server_request") or {} - end_user_id = proxy_server_request.get("body", {}).get("user", None) + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) user_id = standard_logging_payload["metadata"]["user_api_key_user_id"] user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"] user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] @@ -664,13 +664,11 @@ class PrometheusLogger(CustomLogger): # unpack kwargs model = kwargs.get("model", "") - litellm_params = kwargs.get("litellm_params", {}) or {} standard_logging_payload: StandardLoggingPayload = kwargs.get( "standard_logging_object", {} ) - proxy_server_request = litellm_params.get("proxy_server_request") or {} - - end_user_id = proxy_server_request.get("body", {}).get("user", None) + litellm_params = kwargs.get("litellm_params", {}) or {} + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) user_id = standard_logging_payload["metadata"]["user_api_key_user_id"] user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"] user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 2c25b61db..f12226736 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -13,4 +13,6 @@ model_list: vertex_ai_location: "us-east5" litellm_settings: - success_callback: ["langfuse"] \ No newline at end of file + success_callback: ["langfuse"] + callbacks: ["prometheus"] + # disable_end_user_cost_tracking: true diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9d7c120a7..70bf5b523 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -268,6 +268,7 @@ from litellm.types.llms.anthropic import ( from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.router import RouterGeneralSettings from litellm.types.utils import StandardLoggingPayload +from litellm.utils import get_end_user_id_for_cost_tracking try: from litellm._version import version @@ -763,8 +764,7 @@ async def _PROXY_track_cost_callback( ) parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) litellm_params = kwargs.get("litellm_params", {}) or {} - proxy_server_request = litellm_params.get("proxy_server_request") or {} - end_user_id = proxy_server_request.get("body", {}).get("user", None) + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) user_id = metadata.get("user_api_key_user_id", None) team_id = metadata.get("user_api_key_team_id", None) diff --git a/litellm/utils.py b/litellm/utils.py index 003971142..262af3418 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6170,3 +6170,13 @@ class ProviderConfigManager: return litellm.GroqChatConfig() return OpenAIGPTConfig() + + +def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]: + """ + Used for enforcing `disable_end_user_cost_tracking` param. + """ + proxy_server_request = litellm_params.get("proxy_server_request") or {} + if litellm.disable_end_user_cost_tracking: + return None + return proxy_server_request.get("body", {}).get("user", None) diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index 52946ca30..cf1db27e8 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -1012,3 +1012,23 @@ def test_models_by_provider(): for provider in providers: assert provider in models_by_provider.keys() + + +@pytest.mark.parametrize( + "litellm_params, disable_end_user_cost_tracking, expected_end_user_id", + [ + ({}, False, None), + ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"), + ({"proxy_server_request": {"body": {"user": "123"}}}, True, None), + ], +) +def test_get_end_user_id_for_cost_tracking( + litellm_params, disable_end_user_cost_tracking, expected_end_user_id +): + from litellm.utils import get_end_user_id_for_cost_tracking + + litellm.disable_end_user_cost_tracking = disable_end_user_cost_tracking + assert ( + get_end_user_id_for_cost_tracking(litellm_params=litellm_params) + == expected_end_user_id + )