forked from phoenix/litellm-mirror
fix: use standard_logging_payload for the track-cost callback
This commit is contained in:
parent
2639c1971d
commit
9864459f4d
1 changed files with 37 additions and 33 deletions
|
@@ -9,6 +9,7 @@ Updates cost for the following in LiteLLM DB:
|
|||
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
@@ -17,6 +18,7 @@ from litellm.proxy.utils import (
|
|||
get_litellm_metadata_from_kwargs,
|
||||
log_to_opentelemetry,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
|
@@ -49,31 +51,36 @@ async def _PROXY_track_cost_callback(
|
|||
kwargs,
|
||||
completion_response,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
|
||||
)
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||
user_id = metadata.get("user_api_key_user_id", None)
|
||||
team_id = metadata.get("user_api_key_team_id", None)
|
||||
org_id = metadata.get("user_api_key_org_id", None)
|
||||
key_alias = metadata.get("user_api_key_alias", None)
|
||||
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
|
||||
if kwargs.get("response_cost", None) is not None:
|
||||
response_cost = kwargs["response_cost"]
|
||||
user_api_key = metadata.get("user_api_key", None)
|
||||
if kwargs.get("cache_hit", False) is True:
|
||||
response_cost = 0.0
|
||||
verbose_proxy_logger.info(
|
||||
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
if standard_logging_payload is None:
|
||||
raise ValueError(
|
||||
"standard_logging_payload is none in kwargs, cannot track cost without it"
|
||||
)
|
||||
end_user_id = standard_logging_payload.get("end_user")
|
||||
user_api_key = standard_logging_payload.get("metadata", {}).get(
|
||||
"user_api_key_hash"
|
||||
)
|
||||
user_id = standard_logging_payload.get("metadata", {}).get(
|
||||
"user_api_key_user_id"
|
||||
)
|
||||
team_id = standard_logging_payload.get("metadata", {}).get(
|
||||
"user_api_key_team_id"
|
||||
)
|
||||
org_id = standard_logging_payload.get("metadata", {}).get("user_api_key_org_id")
|
||||
key_alias = standard_logging_payload.get("metadata", {}).get(
|
||||
"user_api_key_alias"
|
||||
)
|
||||
response_cost: Optional[float] = standard_logging_payload.get("response_cost")
|
||||
|
||||
# end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
|
||||
end_user_max_budget = standard_logging_payload.get("metadata", {}).get(
|
||||
"user_api_end_user_max_budget"
|
||||
)
|
||||
|
||||
if response_cost is not None:
|
||||
if user_api_key is not None or user_id is not None or team_id is not None:
|
||||
## UPDATE DATABASE
|
||||
await update_database(
|
||||
|
@@ -113,16 +120,13 @@ async def _PROXY_track_cost_callback(
|
|||
"User API key and team id and user id missing from custom callback."
|
||||
)
|
||||
else:
|
||||
if kwargs["stream"] is not True or (
|
||||
kwargs["stream"] is True and "complete_streaming_response" in kwargs
|
||||
):
|
||||
cost_tracking_failure_debug_info = kwargs.get(
|
||||
"response_cost_failure_debug_information"
|
||||
)
|
||||
model = kwargs.get("model")
|
||||
raise Exception(
|
||||
f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
|
||||
)
|
||||
cost_tracking_failure_debug_info = standard_logging_payload.get(
|
||||
"response_cost_failure_debug_info"
|
||||
)
|
||||
model = kwargs.get("model")
|
||||
raise ValueError(
|
||||
f"Failed to write cost to DB, for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
|
||||
model = kwargs.get("model", "")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue