feat(litellm_logging.py): refactor standard_logging_payload function … (#6388)

* feat(litellm_logging.py): refactor standard_logging_payload function to be <50 LOC

fixes issue where usage information was not following typed values

* fix(litellm_logging.py): fix completion start time handling
This commit is contained in:
Krish Dholakia 2024-10-24 18:59:01 -07:00 committed by GitHub
parent d59f8f952d
commit c04c4a82f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 286 additions and 121 deletions

View file

@ -346,8 +346,12 @@ class PrometheusLogger(CustomLogger):
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
if standard_logging_payload is None:
raise ValueError("standard_logging_object is required")
if standard_logging_payload is None or not isinstance(
standard_logging_payload, dict
):
raise ValueError(
f"standard_logging_object is required, got={standard_logging_payload}"
)
model = kwargs.get("model", "")
litellm_params = kwargs.get("litellm_params", {}) or {}
_metadata = litellm_params.get("metadata", {})
@ -991,7 +995,7 @@ class PrometheusLogger(CustomLogger):
"""
from litellm.litellm_core_utils.litellm_logging import (
StandardLoggingMetadata,
get_standard_logging_metadata,
StandardLoggingPayloadSetup,
)
verbose_logger.debug(
@ -1000,8 +1004,10 @@ class PrometheusLogger(CustomLogger):
kwargs,
)
_metadata = kwargs.get("metadata", {})
standard_metadata: StandardLoggingMetadata = get_standard_logging_metadata(
metadata=_metadata
standard_metadata: StandardLoggingMetadata = (
StandardLoggingPayloadSetup.get_standard_logging_metadata(
metadata=_metadata
)
)
_new_model = kwargs.get("model")
self.litellm_deployment_successful_fallbacks.labels(
@ -1023,7 +1029,7 @@ class PrometheusLogger(CustomLogger):
"""
from litellm.litellm_core_utils.litellm_logging import (
StandardLoggingMetadata,
get_standard_logging_metadata,
StandardLoggingPayloadSetup,
)
verbose_logger.debug(
@ -1033,8 +1039,10 @@ class PrometheusLogger(CustomLogger):
)
_new_model = kwargs.get("model")
_metadata = kwargs.get("metadata", {})
standard_metadata: StandardLoggingMetadata = get_standard_logging_metadata(
metadata=_metadata
standard_metadata: StandardLoggingMetadata = (
StandardLoggingPayloadSetup.get_standard_logging_metadata(
metadata=_metadata
)
)
self.litellm_deployment_failed_fallbacks.labels(
requested_model=original_model_group,

View file

@ -12,7 +12,7 @@ import time
import traceback
import uuid
from datetime import datetime as dt_object
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel
@ -51,6 +51,7 @@ from litellm.types.utils import (
StandardPassThroughResponseObject,
TextCompletionResponse,
TranscriptionResponse,
Usage,
)
from litellm.utils import (
_get_base_model_from_metadata,
@ -2454,7 +2455,183 @@ def is_valid_sha256_hash(value: str) -> bool:
return bool(re.fullmatch(r"[a-fA-F0-9]{64}", value))
def get_standard_logging_object_payload( # noqa: PLR0915
class StandardLoggingPayloadSetup:
    """
    Static helpers used by `get_standard_logging_object_payload` to normalize
    raw callback kwargs / response objects into the typed fields of
    `StandardLoggingPayload`.
    """

    @staticmethod
    def cleanup_timestamps(
        start_time: Union[dt_object, float],
        end_time: Union[dt_object, float],
        completion_start_time: Union[dt_object, float],
    ) -> Tuple[float, float, float]:
        """
        Convert datetime objects to epoch floats.

        Returns:
            Tuple of (start_time, end_time, completion_start_time) as floats.

        Raises:
            ValueError: if start_time or end_time is neither datetime nor float.

        Note:
            An invalid completion_start_time does NOT raise — it falls back to
            end_time (e.g. non-streaming calls never set it).
        """
        if isinstance(start_time, datetime.datetime):
            start_time_float = start_time.timestamp()
        elif isinstance(start_time, float):
            start_time_float = start_time
        else:
            raise ValueError(
                f"start_time is required, got={start_time} of type {type(start_time)}"
            )

        if isinstance(end_time, datetime.datetime):
            end_time_float = end_time.timestamp()
        elif isinstance(end_time, float):
            end_time_float = end_time
        else:
            raise ValueError(
                f"end_time is required, got={end_time} of type {type(end_time)}"
            )

        if isinstance(completion_start_time, datetime.datetime):
            completion_start_time_float = completion_start_time.timestamp()
        elif isinstance(completion_start_time, float):
            completion_start_time_float = completion_start_time
        else:
            # best-effort fallback instead of raising
            completion_start_time_float = end_time_float

        return start_time_float, end_time_float, completion_start_time_float

    @staticmethod
    def get_standard_logging_metadata(
        metadata: Optional[Dict[str, Any]]
    ) -> StandardLoggingMetadata:
        """
        Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.

        Args:
            metadata (Optional[Dict[str, Any]]): The original metadata dictionary.

        Returns:
            StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata.

        Note:
            - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
            - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
        """
        # Fallback: all-None object returned verbatim when metadata is absent/malformed
        clean_metadata = StandardLoggingMetadata(
            user_api_key_hash=None,
            user_api_key_alias=None,
            user_api_key_team_id=None,
            user_api_key_user_id=None,
            user_api_key_team_alias=None,
            spend_logs_metadata=None,
            requester_ip_address=None,
            requester_metadata=None,
        )
        if isinstance(metadata, dict):
            # Keep only keys declared on the StandardLoggingMetadata TypedDict
            clean_metadata = StandardLoggingMetadata(
                **{  # type: ignore
                    key: metadata[key]
                    for key in StandardLoggingMetadata.__annotations__.keys()
                    if key in metadata
                }
            )

            if metadata.get("user_api_key") is not None:
                if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
                    # only store it if it already is the hash — never a raw key
                    clean_metadata["user_api_key_hash"] = metadata.get(
                        "user_api_key"
                    )  # this is the hash
        return clean_metadata

    @staticmethod
    def get_usage_from_response_obj(response_obj: Optional[dict]) -> Usage:
        """
        Extract token usage from a raw response dict, always returning a typed
        Usage object (zeroed counts when usage is missing or malformed).
        """
        ## BASE CASE ##
        if response_obj is None:
            return Usage(
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
            )

        usage = response_obj.get("usage", None) or {}
        # `or {}` above guarantees usage is never None here; anything that is
        # neither a dict nor a Usage (e.g. a raw string) yields zeroed usage.
        if isinstance(usage, Usage):
            return usage
        if isinstance(usage, dict):
            return Usage(**usage)
        return Usage(
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
        )

    @staticmethod
    def get_model_cost_information(
        base_model: Optional[str],
        custom_pricing: Optional[bool],
        custom_llm_provider: Optional[str],
        init_response_obj: Union[Any, BaseModel, dict],
    ) -> StandardLoggingModelInformation:
        """
        Resolve the model-cost-map entry for this call.

        Falls back to a None model_map_value when the model cannot be resolved
        or is not present in the cost map — never raises.
        """
        model_cost_name = _select_model_name_for_cost_calc(
            model=None,
            completion_response=init_response_obj,  # type: ignore
            base_model=base_model,
            custom_pricing=custom_pricing,
        )
        if model_cost_name is None:
            model_cost_information = StandardLoggingModelInformation(
                model_map_key="", model_map_value=None
            )
        else:
            try:
                _model_cost_information = litellm.get_model_info(
                    model=model_cost_name, custom_llm_provider=custom_llm_provider
                )
                model_cost_information = StandardLoggingModelInformation(
                    model_map_key=model_cost_name,
                    model_map_value=_model_cost_information,
                )
            except Exception:
                verbose_logger.debug(  # keep in debug otherwise it will trigger on every call
                    "Model={} is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload".format(
                        model_cost_name
                    )
                )
                model_cost_information = StandardLoggingModelInformation(
                    model_map_key=model_cost_name, model_map_value=None
                )
        return model_cost_information

    @staticmethod
    def get_final_response_obj(
        response_obj: dict, init_response_obj: Union[Any, BaseModel, dict], kwargs: dict
    ) -> Optional[Union[dict, str, list]]:
        """
        Get final response object after redacting the message input/output from logging
        """
        if response_obj is not None:
            final_response_obj: Optional[Union[dict, str, list]] = response_obj
        elif isinstance(init_response_obj, list) or isinstance(init_response_obj, str):
            # text-completion / pass-through results may already be plain str/list
            final_response_obj = init_response_obj
        else:
            final_response_obj = None

        modified_final_response_obj = redact_message_input_output_from_logging(
            model_call_details=kwargs,
            result=final_response_obj,
        )

        # redaction may hand back a pydantic model — serialize it for the payload
        if modified_final_response_obj is not None and isinstance(
            modified_final_response_obj, BaseModel
        ):
            final_response_obj = modified_final_response_obj.model_dump()
        else:
            final_response_obj = modified_final_response_obj

        return final_response_obj
def get_standard_logging_object_payload(
kwargs: Optional[dict],
init_response_obj: Union[Any, BaseModel, dict],
start_time: dt_object,
@ -2502,9 +2679,9 @@ def get_standard_logging_object_payload( # noqa: PLR0915
completion_start_time = kwargs.get("completion_start_time", end_time)
call_type = kwargs.get("call_type")
cache_hit = kwargs.get("cache_hit", False)
usage = response_obj.get("usage", None) or {}
if type(usage) is litellm.Usage:
usage = dict(usage)
usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(
response_obj=response_obj
)
id = response_obj.get("id", kwargs.get("litellm_call_id"))
_model_id = metadata.get("model_info", {}).get("id", "")
@ -2517,20 +2694,13 @@ def get_standard_logging_object_payload( # noqa: PLR0915
)
# cleanup timestamps
if isinstance(start_time, datetime.datetime):
start_time_float = start_time.timestamp()
elif isinstance(start_time, float):
start_time_float = start_time
if isinstance(end_time, datetime.datetime):
end_time_float = end_time.timestamp()
elif isinstance(end_time, float):
end_time_float = end_time
if isinstance(completion_start_time, datetime.datetime):
completion_start_time_float = completion_start_time.timestamp()
elif isinstance(completion_start_time, float):
completion_start_time_float = completion_start_time
else:
completion_start_time_float = end_time_float
start_time_float, end_time_float, completion_start_time_float = (
StandardLoggingPayloadSetup.cleanup_timestamps(
start_time=start_time,
end_time=end_time,
completion_start_time=completion_start_time,
)
)
# clean up litellm hidden params
clean_hidden_params = StandardLoggingHiddenParams(
model_id=None,
@ -2548,7 +2718,9 @@ def get_standard_logging_object_payload( # noqa: PLR0915
}
)
# clean up litellm metadata
clean_metadata = get_standard_logging_metadata(metadata=metadata)
clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
metadata=metadata
)
if litellm.cache is not None:
cache_key = litellm.cache.get_cache_key(**kwargs)
@ -2570,58 +2742,21 @@ def get_standard_logging_object_payload( # noqa: PLR0915
## Get model cost information ##
base_model = _get_base_model_from_metadata(model_call_details=kwargs)
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
model_cost_name = _select_model_name_for_cost_calc(
model=None,
completion_response=init_response_obj, # type: ignore
model_cost_information = StandardLoggingPayloadSetup.get_model_cost_information(
base_model=base_model,
custom_pricing=custom_pricing,
custom_llm_provider=kwargs.get("custom_llm_provider"),
init_response_obj=init_response_obj,
)
if model_cost_name is None:
model_cost_information = StandardLoggingModelInformation(
model_map_key="", model_map_value=None
)
else:
custom_llm_provider = kwargs.get("custom_llm_provider", None)
try:
_model_cost_information = litellm.get_model_info(
model=model_cost_name, custom_llm_provider=custom_llm_provider
)
model_cost_information = StandardLoggingModelInformation(
model_map_key=model_cost_name,
model_map_value=_model_cost_information,
)
except Exception:
verbose_logger.debug( # keep in debug otherwise it will trigger on every call
"Model={} is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload".format(
model_cost_name
)
)
model_cost_information = StandardLoggingModelInformation(
model_map_key=model_cost_name, model_map_value=None
)
response_cost: float = kwargs.get("response_cost", 0) or 0.0
if response_obj is not None:
final_response_obj: Optional[Union[dict, str, list]] = response_obj
elif isinstance(init_response_obj, list) or isinstance(init_response_obj, str):
final_response_obj = init_response_obj
else:
final_response_obj = None
modified_final_response_obj = redact_message_input_output_from_logging(
model_call_details=kwargs,
result=final_response_obj,
## get final response object ##
final_response_obj = StandardLoggingPayloadSetup.get_final_response_obj(
response_obj=response_obj,
init_response_obj=init_response_obj,
kwargs=kwargs,
)
if modified_final_response_obj is not None and isinstance(
modified_final_response_obj, BaseModel
):
final_response_obj = modified_final_response_obj.model_dump()
else:
final_response_obj = modified_final_response_obj
payload: StandardLoggingPayload = StandardLoggingPayload(
id=str(id),
call_type=call_type or "",
@ -2635,9 +2770,9 @@ def get_standard_logging_object_payload( # noqa: PLR0915
metadata=clean_metadata,
cache_key=cache_key,
response_cost=response_cost,
total_tokens=usage.get("total_tokens", 0),
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.total_tokens,
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
request_tags=request_tags,
end_user=end_user_id or "",
api_base=litellm_params.get("api_base", ""),
@ -2663,51 +2798,6 @@ def get_standard_logging_object_payload( # noqa: PLR0915
return None
def get_standard_logging_metadata(
    metadata: Optional[Dict[str, Any]]
) -> StandardLoggingMetadata:
    """
    Filter an arbitrary metadata dict down to the keys declared on
    StandardLoggingMetadata.

    Args:
        metadata (Optional[Dict[str, Any]]): The original metadata dictionary.

    Returns:
        StandardLoggingMetadata: the cleaned, typed metadata object.

    Note:
        - A None / non-dict input yields an all-None StandardLoggingMetadata.
        - A 'user_api_key' value that is a valid SHA256 hash is copied into
          'user_api_key_hash'.
    """
    # Default object returned verbatim when metadata is absent or malformed.
    clean_metadata = StandardLoggingMetadata(
        user_api_key_hash=None,
        user_api_key_alias=None,
        user_api_key_team_id=None,
        user_api_key_user_id=None,
        user_api_key_team_alias=None,
        spend_logs_metadata=None,
        requester_ip_address=None,
        requester_metadata=None,
    )
    if isinstance(metadata, dict):
        # Keep only the keys StandardLoggingMetadata declares.
        allowed_keys = StandardLoggingMetadata.__annotations__.keys()
        filtered = {key: metadata[key] for key in allowed_keys if key in metadata}
        clean_metadata = StandardLoggingMetadata(**filtered)  # type: ignore

        raw_api_key = metadata.get("user_api_key")
        if raw_api_key is not None and is_valid_sha256_hash(str(raw_api_key)):
            clean_metadata["user_api_key_hash"] = metadata.get(
                "user_api_key"
            )  # this is the hash
    return clean_metadata
def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
if litellm_params is None:
litellm_params = {}

View file

@ -0,0 +1,67 @@
"""
Unit tests for StandardLoggingPayloadSetup
"""
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
from pydantic.main import Model
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
import pytest
import litellm
from litellm.types.utils import Usage
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
@pytest.mark.parametrize(
    "response_obj,expected_values",
    [
        (None, (0, 0, 0)),  # no response object at all
        ({}, (0, 0, 0)),  # response without a usage key
        (
            # plain usage dict
            {
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 20,
                    "total_tokens": 30,
                }
            },
            (10, 20, 30),
        ),
        (
            # already-typed litellm.Usage object
            {"usage": Usage(prompt_tokens=15, completion_tokens=25, total_tokens=40)},
            (15, 25, 40),
        ),
        ({"usage": "invalid"}, (0, 0, 0)),  # wrong type falls back to zeros
        ({"usage": None}, (0, 0, 0)),  # explicit None falls back to zeros
    ],
)
def test_get_usage(response_obj, expected_values):
    """
    get_usage_from_response_obj must always hand back integer token counts,
    regardless of how malformed the incoming response object is.
    """
    usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(response_obj)

    actual = (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)

    # Every count is a real int, never None or a str.
    for token_count in actual:
        assert isinstance(token_count, int)

    # Counts match the expected (prompt, completion, total) triple.
    assert actual == tuple(expected_values)