From 9864459f4debc8236e3873f2be046c042187f5ab Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 4 Nov 2024 11:01:58 -0800
Subject: [PATCH] fix use standard_logging_payload for track cost callback

---
 .../proxy/hooks/proxy_track_cost_callback.py | 70 ++++++++++---------
 1 file changed, 37 insertions(+), 33 deletions(-)

diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py
index 40e339516..6d5b441cd 100644
--- a/litellm/proxy/hooks/proxy_track_cost_callback.py
+++ b/litellm/proxy/hooks/proxy_track_cost_callback.py
@@ -9,6 +9,7 @@ Updates cost for the following in LiteLLM DB:
 
 import asyncio
 import traceback
+from typing import Optional
 
 import litellm
 from litellm._logging import verbose_proxy_logger
@@ -17,6 +18,7 @@ from litellm.proxy.utils import (
     get_litellm_metadata_from_kwargs,
     log_to_opentelemetry,
 )
+from litellm.types.utils import StandardLoggingPayload
 
 
 @log_to_opentelemetry
@@ -49,31 +51,36 @@ async def _PROXY_track_cost_callback(
             kwargs,
             completion_response,
         )
-        verbose_proxy_logger.debug(
-            f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
-        )
         parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
-        litellm_params = kwargs.get("litellm_params", {}) or {}
-        proxy_server_request = litellm_params.get("proxy_server_request") or {}
-        end_user_id = proxy_server_request.get("body", {}).get("user", None)
-        metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
-        user_id = metadata.get("user_api_key_user_id", None)
-        team_id = metadata.get("user_api_key_team_id", None)
-        org_id = metadata.get("user_api_key_org_id", None)
-        key_alias = metadata.get("user_api_key_alias", None)
-        end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
-        if kwargs.get("response_cost", None) is not None:
-            response_cost = kwargs["response_cost"]
-            user_api_key = metadata.get("user_api_key", None)
-            if kwargs.get("cache_hit", False) is True:
-                response_cost = 0.0
-                verbose_proxy_logger.info(
-                    f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
-                )
-
-            verbose_proxy_logger.debug(
-                f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
+        standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+            "standard_logging_object", None
+        )
+        if standard_logging_payload is None:
+            raise ValueError(
+                "standard_logging_payload is none in kwargs, cannot track cost without it"
             )
+        end_user_id = standard_logging_payload.get("end_user")
+        user_api_key = standard_logging_payload.get("metadata", {}).get(
+            "user_api_key_hash"
+        )
+        user_id = standard_logging_payload.get("metadata", {}).get(
+            "user_api_key_user_id"
+        )
+        team_id = standard_logging_payload.get("metadata", {}).get(
+            "user_api_key_team_id"
+        )
+        org_id = standard_logging_payload.get("metadata", {}).get("user_api_key_org_id")
+        key_alias = standard_logging_payload.get("metadata", {}).get(
+            "user_api_key_alias"
+        )
+        response_cost: Optional[float] = standard_logging_payload.get("response_cost")
+
+        # end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
+        end_user_max_budget = standard_logging_payload.get("metadata", {}).get(
+            "user_api_end_user_max_budget"
+        )
+
+        if response_cost is not None:
             if user_api_key is not None or user_id is not None or team_id is not None:
                 ## UPDATE DATABASE
                 await update_database(
@@ -113,16 +120,13 @@ async def _PROXY_track_cost_callback(
                     "User API key and team id and user id missing from custom callback."
                 )
         else:
-            if kwargs["stream"] is not True or (
-                kwargs["stream"] is True and "complete_streaming_response" in kwargs
-            ):
-                cost_tracking_failure_debug_info = kwargs.get(
-                    "response_cost_failure_debug_information"
-                )
-                model = kwargs.get("model")
-                raise Exception(
-                    f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
-                )
+            cost_tracking_failure_debug_info = standard_logging_payload.get(
+                "response_cost_failure_debug_info"
+            )
+            model = kwargs.get("model")
+            raise ValueError(
+                f"Failed to write cost to DB, for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+            )
     except Exception as e:
         error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
         model = kwargs.get("model", "")