fix use standard_logging_payload for track cost callback

This commit is contained in:
Ishaan Jaff 2024-11-04 11:01:58 -08:00
parent 2639c1971d
commit 9864459f4d

View file

@ -9,6 +9,7 @@ Updates cost for the following in LiteLLM DB:
import asyncio import asyncio
import traceback import traceback
from typing import Optional
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
@ -17,6 +18,7 @@ from litellm.proxy.utils import (
get_litellm_metadata_from_kwargs, get_litellm_metadata_from_kwargs,
log_to_opentelemetry, log_to_opentelemetry,
) )
from litellm.types.utils import StandardLoggingPayload
@log_to_opentelemetry @log_to_opentelemetry
@ -49,31 +51,36 @@ async def _PROXY_track_cost_callback(
kwargs, kwargs,
completion_response, completion_response,
) )
verbose_proxy_logger.debug(
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
litellm_params = kwargs.get("litellm_params", {}) or {} standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
proxy_server_request = litellm_params.get("proxy_server_request") or {} "standard_logging_object", None
end_user_id = proxy_server_request.get("body", {}).get("user", None) )
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) if standard_logging_payload is None:
user_id = metadata.get("user_api_key_user_id", None) raise ValueError(
team_id = metadata.get("user_api_key_team_id", None) "standard_logging_payload is none in kwargs, cannot track cost without it"
org_id = metadata.get("user_api_key_org_id", None)
key_alias = metadata.get("user_api_key_alias", None)
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
if kwargs.get("response_cost", None) is not None:
response_cost = kwargs["response_cost"]
user_api_key = metadata.get("user_api_key", None)
if kwargs.get("cache_hit", False) is True:
response_cost = 0.0
verbose_proxy_logger.info(
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
)
verbose_proxy_logger.debug(
f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
) )
end_user_id = standard_logging_payload.get("end_user")
user_api_key = standard_logging_payload.get("metadata", {}).get(
"user_api_key_hash"
)
user_id = standard_logging_payload.get("metadata", {}).get(
"user_api_key_user_id"
)
team_id = standard_logging_payload.get("metadata", {}).get(
"user_api_key_team_id"
)
org_id = standard_logging_payload.get("metadata", {}).get("user_api_key_org_id")
key_alias = standard_logging_payload.get("metadata", {}).get(
"user_api_key_alias"
)
response_cost: Optional[float] = standard_logging_payload.get("response_cost")
# end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
end_user_max_budget = standard_logging_payload.get("metadata", {}).get(
"user_api_end_user_max_budget"
)
if response_cost is not None:
if user_api_key is not None or user_id is not None or team_id is not None: if user_api_key is not None or user_id is not None or team_id is not None:
## UPDATE DATABASE ## UPDATE DATABASE
await update_database( await update_database(
@ -113,16 +120,13 @@ async def _PROXY_track_cost_callback(
"User API key and team id and user id missing from custom callback." "User API key and team id and user id missing from custom callback."
) )
else: else:
if kwargs["stream"] is not True or ( cost_tracking_failure_debug_info = standard_logging_payload.get(
kwargs["stream"] is True and "complete_streaming_response" in kwargs "response_cost_failure_debug_info"
): )
cost_tracking_failure_debug_info = kwargs.get( model = kwargs.get("model")
"response_cost_failure_debug_information" raise ValueError(
) f"Failed to write cost to DB, for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
model = kwargs.get("model") )
raise Exception(
f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
)
except Exception as e: except Exception as e:
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}" error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
model = kwargs.get("model", "") model = kwargs.get("model", "")