diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 69d6adca4..9239b7b90 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -2359,6 +2359,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(_mlflow_logger) return _mlflow_logger # type: ignore + def get_custom_logger_compatible_class( logging_integration: litellm._custom_logger_compatible_callbacks_literal, ) -> Optional[CustomLogger]: @@ -2719,6 +2720,31 @@ class StandardLoggingPayloadSetup: return clean_hidden_params +class StandardLoggingPayloadAccessors: + """ + Accessor methods for StandardLoggingPayload + + Class that allows easily reading fields from StandardLoggingPayload + + """ + + @staticmethod + def get_custom_llm_provider_from_standard_logging_payload( + standard_logging_payload: Optional[StandardLoggingPayload], + ) -> Optional[str]: + """ + Accessor method to safely get custom_llm_provider from standard_logging_payload + """ + if standard_logging_payload is None: + return None + model_map_information = ( + standard_logging_payload.get("model_map_information") or {} + ) + model_map_value = model_map_information.get("model_map_value") or {} + custom_llm_provider = model_map_value.get("litellm_provider") + return custom_llm_provider + + def get_standard_logging_object_payload( kwargs: Optional[dict], init_response_obj: Union[Any, BaseModel, dict], diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index 6f3d1b522..e6a79b05e 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -10,8 +10,10 @@ from pydantic import BaseModel import litellm from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadAccessors from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload from litellm.proxy.utils import PrismaClient, hash_token +from litellm.types.utils import StandardLoggingPayload def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool: @@ -49,6 +51,10 @@ def get_logging_payload( response_obj = {} # standardize this function to be used across, s3, dynamoDB, langfuse logging litellm_params = kwargs.get("litellm_params", {}) + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + metadata = ( litellm_params.get("metadata", {}) or {} ) # if litellm_params['metadata'] == None @@ -150,7 +156,9 @@ def get_logging_payload( request_tags=request_tags, end_user=end_user_id or "", api_base=litellm_params.get("api_base", ""), - custom_llm_provider=litellm_params.get("custom_llm_provider", None), + custom_llm_provider=StandardLoggingPayloadAccessors.get_custom_llm_provider_from_standard_logging_payload( + standard_logging_payload + ), model_group=_model_group, model_id=_model_id, requester_ip_address=clean_metadata.get("requester_ip_address", None),