fix use standard_logging_payload for track cost callback

This commit is contained in:
Ishaan Jaff 2024-11-04 11:01:58 -08:00
parent 2639c1971d
commit 9864459f4d

View file

@ -9,6 +9,7 @@ Updates cost for the following in LiteLLM DB:
import asyncio
import traceback
from typing import Optional
import litellm
from litellm._logging import verbose_proxy_logger
@ -17,6 +18,7 @@ from litellm.proxy.utils import (
get_litellm_metadata_from_kwargs,
log_to_opentelemetry,
)
from litellm.types.utils import StandardLoggingPayload
@log_to_opentelemetry
@ -49,31 +51,36 @@ async def _PROXY_track_cost_callback(
kwargs,
completion_response,
)
verbose_proxy_logger.debug(
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
user_id = metadata.get("user_api_key_user_id", None)
team_id = metadata.get("user_api_key_team_id", None)
org_id = metadata.get("user_api_key_org_id", None)
key_alias = metadata.get("user_api_key_alias", None)
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
if kwargs.get("response_cost", None) is not None:
response_cost = kwargs["response_cost"]
user_api_key = metadata.get("user_api_key", None)
if kwargs.get("cache_hit", False) is True:
response_cost = 0.0
verbose_proxy_logger.info(
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
)
verbose_proxy_logger.debug(
f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if standard_logging_payload is None:
raise ValueError(
"standard_logging_payload is none in kwargs, cannot track cost without it"
)
end_user_id = standard_logging_payload.get("end_user")
user_api_key = standard_logging_payload.get("metadata", {}).get(
"user_api_key_hash"
)
user_id = standard_logging_payload.get("metadata", {}).get(
"user_api_key_user_id"
)
team_id = standard_logging_payload.get("metadata", {}).get(
"user_api_key_team_id"
)
org_id = standard_logging_payload.get("metadata", {}).get("user_api_key_org_id")
key_alias = standard_logging_payload.get("metadata", {}).get(
"user_api_key_alias"
)
response_cost: Optional[float] = standard_logging_payload.get("response_cost")
# end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
end_user_max_budget = standard_logging_payload.get("metadata", {}).get(
"user_api_end_user_max_budget"
)
if response_cost is not None:
if user_api_key is not None or user_id is not None or team_id is not None:
## UPDATE DATABASE
await update_database(
@ -113,16 +120,13 @@ async def _PROXY_track_cost_callback(
"User API key and team id and user id missing from custom callback."
)
else:
if kwargs["stream"] is not True or (
kwargs["stream"] is True and "complete_streaming_response" in kwargs
):
cost_tracking_failure_debug_info = kwargs.get(
"response_cost_failure_debug_information"
)
model = kwargs.get("model")
raise Exception(
f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
)
cost_tracking_failure_debug_info = standard_logging_payload.get(
"response_cost_failure_debug_info"
)
model = kwargs.get("model")
raise ValueError(
f"Failed to write cost to DB, for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
)
except Exception as e:
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
model = kwargs.get("model", "")