diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index 8a4f409b6..bf19c364e 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -346,8 +346,12 @@ class PrometheusLogger(CustomLogger): standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object" ) - if standard_logging_payload is None: - raise ValueError("standard_logging_object is required") + if standard_logging_payload is None or not isinstance( + standard_logging_payload, dict + ): + raise ValueError( + f"standard_logging_object is required, got={standard_logging_payload}" + ) model = kwargs.get("model", "") litellm_params = kwargs.get("litellm_params", {}) or {} _metadata = litellm_params.get("metadata", {}) @@ -991,7 +995,7 @@ class PrometheusLogger(CustomLogger): """ from litellm.litellm_core_utils.litellm_logging import ( StandardLoggingMetadata, - get_standard_logging_metadata, + StandardLoggingPayloadSetup, ) verbose_logger.debug( @@ -1000,8 +1004,10 @@ class PrometheusLogger(CustomLogger): kwargs, ) _metadata = kwargs.get("metadata", {}) - standard_metadata: StandardLoggingMetadata = get_standard_logging_metadata( - metadata=_metadata + standard_metadata: StandardLoggingMetadata = ( + StandardLoggingPayloadSetup.get_standard_logging_metadata( + metadata=_metadata + ) ) _new_model = kwargs.get("model") self.litellm_deployment_successful_fallbacks.labels( @@ -1023,7 +1029,7 @@ class PrometheusLogger(CustomLogger): """ from litellm.litellm_core_utils.litellm_logging import ( StandardLoggingMetadata, - get_standard_logging_metadata, + StandardLoggingPayloadSetup, ) verbose_logger.debug( @@ -1033,8 +1039,10 @@ class PrometheusLogger(CustomLogger): ) _new_model = kwargs.get("model") _metadata = kwargs.get("metadata", {}) - standard_metadata: StandardLoggingMetadata = get_standard_logging_metadata( - metadata=_metadata + standard_metadata: StandardLoggingMetadata = ( + 
StandardLoggingPayloadSetup.get_standard_logging_metadata( + metadata=_metadata + ) ) self.litellm_deployment_failed_fallbacks.labels( requested_model=original_model_group, diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 0a298d33b..fd7335201 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -12,7 +12,7 @@ import time import traceback import uuid from datetime import datetime as dt_object -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union from pydantic import BaseModel @@ -51,6 +51,7 @@ from litellm.types.utils import ( StandardPassThroughResponseObject, TextCompletionResponse, TranscriptionResponse, + Usage, ) from litellm.utils import ( _get_base_model_from_metadata, @@ -2454,7 +2455,183 @@ def is_valid_sha256_hash(value: str) -> bool: return bool(re.fullmatch(r"[a-fA-F0-9]{64}", value)) -def get_standard_logging_object_payload( # noqa: PLR0915 +class StandardLoggingPayloadSetup: + @staticmethod + def cleanup_timestamps( + start_time: Union[dt_object, float], + end_time: Union[dt_object, float], + completion_start_time: Union[dt_object, float], + ) -> Tuple[float, float, float]: + """ + Convert datetime objects to floats + """ + + if isinstance(start_time, datetime.datetime): + start_time_float = start_time.timestamp() + elif isinstance(start_time, float): + start_time_float = start_time + else: + raise ValueError( + f"start_time is required, got={start_time} of type {type(start_time)}" + ) + + if isinstance(end_time, datetime.datetime): + end_time_float = end_time.timestamp() + elif isinstance(end_time, float): + end_time_float = end_time + else: + raise ValueError( + f"end_time is required, got={end_time} of type {type(end_time)}" + ) + + if isinstance(completion_start_time, datetime.datetime): + completion_start_time_float = 
completion_start_time.timestamp() + elif isinstance(completion_start_time, float): + completion_start_time_float = completion_start_time + else: + completion_start_time_float = end_time_float + + return start_time_float, end_time_float, completion_start_time_float + + @staticmethod + def get_standard_logging_metadata( + metadata: Optional[Dict[str, Any]] + ) -> StandardLoggingMetadata: + """ + Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata. + + Args: + metadata (Optional[Dict[str, Any]]): The original metadata dictionary. + + Returns: + StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata. + + Note: + - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned. + - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'. + """ + # Initialize with default values + clean_metadata = StandardLoggingMetadata( + user_api_key_hash=None, + user_api_key_alias=None, + user_api_key_team_id=None, + user_api_key_user_id=None, + user_api_key_team_alias=None, + spend_logs_metadata=None, + requester_ip_address=None, + requester_metadata=None, + ) + if isinstance(metadata, dict): + # Filter the metadata dictionary to include only the specified keys + clean_metadata = StandardLoggingMetadata( + **{ # type: ignore + key: metadata[key] + for key in StandardLoggingMetadata.__annotations__.keys() + if key in metadata + } + ) + + if metadata.get("user_api_key") is not None: + if is_valid_sha256_hash(str(metadata.get("user_api_key"))): + clean_metadata["user_api_key_hash"] = metadata.get( + "user_api_key" + ) # this is the hash + return clean_metadata + + @staticmethod + def get_usage_from_response_obj(response_obj: Optional[dict]) -> Usage: + ## BASE CASE ## + if response_obj is None: + return Usage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + + usage = response_obj.get("usage", 
None) or {} + if usage is None or ( + not isinstance(usage, dict) and not isinstance(usage, Usage) + ): + return Usage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + elif isinstance(usage, Usage): + return usage + elif isinstance(usage, dict): + return Usage(**usage) + + raise ValueError(f"usage is required, got={usage} of type {type(usage)}") + + @staticmethod + def get_model_cost_information( + base_model: Optional[str], + custom_pricing: Optional[bool], + custom_llm_provider: Optional[str], + init_response_obj: Union[Any, BaseModel, dict], + ) -> StandardLoggingModelInformation: + + model_cost_name = _select_model_name_for_cost_calc( + model=None, + completion_response=init_response_obj, # type: ignore + base_model=base_model, + custom_pricing=custom_pricing, + ) + if model_cost_name is None: + model_cost_information = StandardLoggingModelInformation( + model_map_key="", model_map_value=None + ) + else: + try: + _model_cost_information = litellm.get_model_info( + model=model_cost_name, custom_llm_provider=custom_llm_provider + ) + model_cost_information = StandardLoggingModelInformation( + model_map_key=model_cost_name, + model_map_value=_model_cost_information, + ) + except Exception: + verbose_logger.debug( # keep in debug otherwise it will trigger on every call + "Model={} is not mapped in model cost map. 
Defaulting to None model_cost_information for standard_logging_payload".format( + model_cost_name + ) + ) + model_cost_information = StandardLoggingModelInformation( + model_map_key=model_cost_name, model_map_value=None + ) + return model_cost_information + + @staticmethod + def get_final_response_obj( + response_obj: dict, init_response_obj: Union[Any, BaseModel, dict], kwargs: dict + ) -> Optional[Union[dict, str, list]]: + """ + Get final response object after redacting the message input/output from logging + """ + if response_obj is not None: + final_response_obj: Optional[Union[dict, str, list]] = response_obj + elif isinstance(init_response_obj, list) or isinstance(init_response_obj, str): + final_response_obj = init_response_obj + else: + final_response_obj = None + + modified_final_response_obj = redact_message_input_output_from_logging( + model_call_details=kwargs, + result=final_response_obj, + ) + + if modified_final_response_obj is not None and isinstance( + modified_final_response_obj, BaseModel + ): + final_response_obj = modified_final_response_obj.model_dump() + else: + final_response_obj = modified_final_response_obj + + return final_response_obj + + +def get_standard_logging_object_payload( kwargs: Optional[dict], init_response_obj: Union[Any, BaseModel, dict], start_time: dt_object, @@ -2502,9 +2679,9 @@ def get_standard_logging_object_payload( # noqa: PLR0915 completion_start_time = kwargs.get("completion_start_time", end_time) call_type = kwargs.get("call_type") cache_hit = kwargs.get("cache_hit", False) - usage = response_obj.get("usage", None) or {} - if type(usage) is litellm.Usage: - usage = dict(usage) + usage = StandardLoggingPayloadSetup.get_usage_from_response_obj( + response_obj=response_obj + ) id = response_obj.get("id", kwargs.get("litellm_call_id")) _model_id = metadata.get("model_info", {}).get("id", "") @@ -2517,20 +2694,13 @@ def get_standard_logging_object_payload( # noqa: PLR0915 ) # cleanup timestamps - if 
isinstance(start_time, datetime.datetime): - start_time_float = start_time.timestamp() - elif isinstance(start_time, float): - start_time_float = start_time - if isinstance(end_time, datetime.datetime): - end_time_float = end_time.timestamp() - elif isinstance(end_time, float): - end_time_float = end_time - if isinstance(completion_start_time, datetime.datetime): - completion_start_time_float = completion_start_time.timestamp() - elif isinstance(completion_start_time, float): - completion_start_time_float = completion_start_time - else: - completion_start_time_float = end_time_float + start_time_float, end_time_float, completion_start_time_float = ( + StandardLoggingPayloadSetup.cleanup_timestamps( + start_time=start_time, + end_time=end_time, + completion_start_time=completion_start_time, + ) + ) # clean up litellm hidden params clean_hidden_params = StandardLoggingHiddenParams( model_id=None, @@ -2548,7 +2718,9 @@ def get_standard_logging_object_payload( # noqa: PLR0915 } ) # clean up litellm metadata - clean_metadata = get_standard_logging_metadata(metadata=metadata) + clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata( + metadata=metadata + ) if litellm.cache is not None: cache_key = litellm.cache.get_cache_key(**kwargs) @@ -2570,58 +2742,21 @@ def get_standard_logging_object_payload( # noqa: PLR0915 ## Get model cost information ## base_model = _get_base_model_from_metadata(model_call_details=kwargs) custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params) - model_cost_name = _select_model_name_for_cost_calc( - model=None, - completion_response=init_response_obj, # type: ignore + model_cost_information = StandardLoggingPayloadSetup.get_model_cost_information( base_model=base_model, custom_pricing=custom_pricing, + custom_llm_provider=kwargs.get("custom_llm_provider"), + init_response_obj=init_response_obj, ) - if model_cost_name is None: - model_cost_information = StandardLoggingModelInformation( - model_map_key="", 
model_map_value=None - ) - else: - custom_llm_provider = kwargs.get("custom_llm_provider", None) - - try: - _model_cost_information = litellm.get_model_info( - model=model_cost_name, custom_llm_provider=custom_llm_provider - ) - model_cost_information = StandardLoggingModelInformation( - model_map_key=model_cost_name, - model_map_value=_model_cost_information, - ) - except Exception: - verbose_logger.debug( # keep in debug otherwise it will trigger on every call - "Model={} is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload".format( - model_cost_name - ) - ) - model_cost_information = StandardLoggingModelInformation( - model_map_key=model_cost_name, model_map_value=None - ) - response_cost: float = kwargs.get("response_cost", 0) or 0.0 - if response_obj is not None: - final_response_obj: Optional[Union[dict, str, list]] = response_obj - elif isinstance(init_response_obj, list) or isinstance(init_response_obj, str): - final_response_obj = init_response_obj - else: - final_response_obj = None - - modified_final_response_obj = redact_message_input_output_from_logging( - model_call_details=kwargs, - result=final_response_obj, + ## get final response object ## + final_response_obj = StandardLoggingPayloadSetup.get_final_response_obj( + response_obj=response_obj, + init_response_obj=init_response_obj, + kwargs=kwargs, ) - if modified_final_response_obj is not None and isinstance( - modified_final_response_obj, BaseModel - ): - final_response_obj = modified_final_response_obj.model_dump() - else: - final_response_obj = modified_final_response_obj - payload: StandardLoggingPayload = StandardLoggingPayload( id=str(id), call_type=call_type or "", @@ -2635,9 +2770,9 @@ def get_standard_logging_object_payload( # noqa: PLR0915 metadata=clean_metadata, cache_key=cache_key, response_cost=response_cost, - total_tokens=usage.get("total_tokens", 0), - prompt_tokens=usage.get("prompt_tokens", 0), - 
completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.total_tokens, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, request_tags=request_tags, end_user=end_user_id or "", api_base=litellm_params.get("api_base", ""), @@ -2663,51 +2798,6 @@ def get_standard_logging_object_payload( # noqa: PLR0915 return None -def get_standard_logging_metadata( - metadata: Optional[Dict[str, Any]] -) -> StandardLoggingMetadata: - """ - Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata. - - Args: - metadata (Optional[Dict[str, Any]]): The original metadata dictionary. - - Returns: - StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata. - - Note: - - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned. - - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'. - """ - # Initialize with default values - clean_metadata = StandardLoggingMetadata( - user_api_key_hash=None, - user_api_key_alias=None, - user_api_key_team_id=None, - user_api_key_user_id=None, - user_api_key_team_alias=None, - spend_logs_metadata=None, - requester_ip_address=None, - requester_metadata=None, - ) - if isinstance(metadata, dict): - # Filter the metadata dictionary to include only the specified keys - clean_metadata = StandardLoggingMetadata( - **{ # type: ignore - key: metadata[key] - for key in StandardLoggingMetadata.__annotations__.keys() - if key in metadata - } - ) - - if metadata.get("user_api_key") is not None: - if is_valid_sha256_hash(str(metadata.get("user_api_key"))): - clean_metadata["user_api_key_hash"] = metadata.get( - "user_api_key" - ) # this is the hash - return clean_metadata - - def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]): if litellm_params is None: litellm_params = {} diff --git 
a/tests/logging_callback_tests/test_standard_logging_payload.py b/tests/logging_callback_tests/test_standard_logging_payload.py new file mode 100644 index 000000000..f6599a005 --- /dev/null +++ b/tests/logging_callback_tests/test_standard_logging_payload.py @@ -0,0 +1,67 @@ +""" +Unit tests for StandardLoggingPayloadSetup +""" + +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock + + + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system-path + +import pytest +import litellm +from litellm.types.utils import Usage +from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + + +@pytest.mark.parametrize( + "response_obj,expected_values", + [ + # Test None input + (None, (0, 0, 0)), + # Test empty dict + ({}, (0, 0, 0)), + # Test valid usage dict + ( + { + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30, + } + }, + (10, 20, 30), + ), + # Test with litellm.Usage object + ( + {"usage": Usage(prompt_tokens=15, completion_tokens=25, total_tokens=40)}, + (15, 25, 40), + ), + # Test invalid usage type + ({"usage": "invalid"}, (0, 0, 0)), + # Test None usage + ({"usage": None}, (0, 0, 0)), + ], +) +def test_get_usage(response_obj, expected_values): + """ + Make sure values returned from get_usage are always integers + """ + + usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(response_obj) + + # Check types + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + + # Check values + assert usage.prompt_tokens == expected_values[0] + assert usage.completion_tokens == expected_values[1] + assert usage.total_tokens == expected_values[2]