forked from phoenix/litellm-mirror
fix(utils.py): allow disabling end user cost tracking with new param
Allows proxy admin to disable cost tracking for end user - keeps prometheus metrics small
This commit is contained in:
parent
1c9a8c0b68
commit
5a698c678a
6 changed files with 40 additions and 9 deletions
|
@ -280,6 +280,7 @@ default_max_internal_user_budget: Optional[float] = None
|
||||||
max_internal_user_budget: Optional[float] = None
|
max_internal_user_budget: Optional[float] = None
|
||||||
internal_user_budget_duration: Optional[str] = None
|
internal_user_budget_duration: Optional[str] = None
|
||||||
max_end_user_budget: Optional[float] = None
|
max_end_user_budget: Optional[float] = None
|
||||||
|
disable_end_user_cost_tracking: Optional[bool] = None
|
||||||
#### REQUEST PRIORITIZATION ####
|
#### REQUEST PRIORITIZATION ####
|
||||||
priority_reservation: Optional[Dict[str, float]] = None
|
priority_reservation: Optional[Dict[str, float]] = None
|
||||||
#### RELIABILITY ####
|
#### RELIABILITY ####
|
||||||
|
|
|
@ -18,6 +18,7 @@ from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.types.integrations.prometheus import *
|
from litellm.types.integrations.prometheus import *
|
||||||
from litellm.types.utils import StandardLoggingPayload
|
from litellm.types.utils import StandardLoggingPayload
|
||||||
|
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||||
|
|
||||||
|
|
||||||
class PrometheusLogger(CustomLogger):
|
class PrometheusLogger(CustomLogger):
|
||||||
|
@ -364,8 +365,7 @@ class PrometheusLogger(CustomLogger):
|
||||||
model = kwargs.get("model", "")
|
model = kwargs.get("model", "")
|
||||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||||
_metadata = litellm_params.get("metadata", {})
|
_metadata = litellm_params.get("metadata", {})
|
||||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
|
||||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
|
||||||
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
|
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
|
||||||
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
|
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
|
||||||
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
|
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
|
||||||
|
@ -664,13 +664,11 @@ class PrometheusLogger(CustomLogger):
|
||||||
|
|
||||||
# unpack kwargs
|
# unpack kwargs
|
||||||
model = kwargs.get("model", "")
|
model = kwargs.get("model", "")
|
||||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
|
||||||
standard_logging_payload: StandardLoggingPayload = kwargs.get(
|
standard_logging_payload: StandardLoggingPayload = kwargs.get(
|
||||||
"standard_logging_object", {}
|
"standard_logging_object", {}
|
||||||
)
|
)
|
||||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||||
|
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
|
||||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
|
||||||
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
|
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
|
||||||
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
|
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
|
||||||
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
|
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
|
||||||
|
|
|
@ -13,4 +13,6 @@ model_list:
|
||||||
vertex_ai_location: "us-east5"
|
vertex_ai_location: "us-east5"
|
||||||
|
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
success_callback: ["langfuse"]
|
success_callback: ["langfuse"]
|
||||||
|
callbacks: ["prometheus"]
|
||||||
|
# disable_end_user_cost_tracking: true
|
||||||
|
|
|
@ -268,6 +268,7 @@ from litellm.types.llms.anthropic import (
|
||||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||||
from litellm.types.router import RouterGeneralSettings
|
from litellm.types.router import RouterGeneralSettings
|
||||||
from litellm.types.utils import StandardLoggingPayload
|
from litellm.types.utils import StandardLoggingPayload
|
||||||
|
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from litellm._version import version
|
from litellm._version import version
|
||||||
|
@ -763,8 +764,7 @@ async def _PROXY_track_cost_callback(
|
||||||
)
|
)
|
||||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
|
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
|
||||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
|
||||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
|
||||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||||
user_id = metadata.get("user_api_key_user_id", None)
|
user_id = metadata.get("user_api_key_user_id", None)
|
||||||
team_id = metadata.get("user_api_key_team_id", None)
|
team_id = metadata.get("user_api_key_team_id", None)
|
||||||
|
|
|
@ -6170,3 +6170,13 @@ class ProviderConfigManager:
|
||||||
return litellm.GroqChatConfig()
|
return litellm.GroqChatConfig()
|
||||||
|
|
||||||
return OpenAIGPTConfig()
|
return OpenAIGPTConfig()
|
||||||
|
|
||||||
|
|
||||||
|
def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Used for enforcing `disable_end_user_cost_tracking` param.
|
||||||
|
"""
|
||||||
|
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||||
|
if litellm.disable_end_user_cost_tracking:
|
||||||
|
return None
|
||||||
|
return proxy_server_request.get("body", {}).get("user", None)
|
||||||
|
|
|
@ -1012,3 +1012,23 @@ def test_models_by_provider():
|
||||||
|
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
assert provider in models_by_provider.keys()
|
assert provider in models_by_provider.keys()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"litellm_params, disable_end_user_cost_tracking, expected_end_user_id",
|
||||||
|
[
|
||||||
|
({}, False, None),
|
||||||
|
({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
|
||||||
|
({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_end_user_id_for_cost_tracking(
|
||||||
|
litellm_params, disable_end_user_cost_tracking, expected_end_user_id
|
||||||
|
):
|
||||||
|
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||||
|
|
||||||
|
litellm.disable_end_user_cost_tracking = disable_end_user_cost_tracking
|
||||||
|
assert (
|
||||||
|
get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
|
||||||
|
== expected_end_user_id
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue