forked from phoenix/litellm-mirror
feat(litellm_logging.py): refactor standard_logging_payload function … (#6388)
* feat(litellm_logging.py): refactor standard_logging_payload function to be <50 LOC; fixes an issue where usage information was not following typed values
* fix(litellm_logging.py): fix completion start time handling
This commit is contained in: parent d59f8f952d, commit c04c4a82f1
3 changed files with 286 additions and 121 deletions
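In short: the module-level helper get_standard_logging_metadata, together with the inline timestamp, usage, model-cost, and response-object handling, moves onto a new StandardLoggingPayloadSetup class. A minimal before/after sketch of a call site, assuming the handler context shown in the diffs below:

    # before: module-level helper
    standard_metadata = get_standard_logging_metadata(metadata=_metadata)

    # after: the same helper as a static method on the new setup class
    standard_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
        metadata=_metadata
    )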
@@ -346,8 +346,12 @@ class PrometheusLogger(CustomLogger):
         standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
             "standard_logging_object"
         )
-        if standard_logging_payload is None:
-            raise ValueError("standard_logging_object is required")
+        if standard_logging_payload is None or not isinstance(
+            standard_logging_payload, dict
+        ):
+            raise ValueError(
+                f"standard_logging_object is required, got={standard_logging_payload}"
+            )
         model = kwargs.get("model", "")
         litellm_params = kwargs.get("litellm_params", {}) or {}
         _metadata = litellm_params.get("metadata", {})
@@ -991,7 +995,7 @@ class PrometheusLogger(CustomLogger):
         """
         from litellm.litellm_core_utils.litellm_logging import (
             StandardLoggingMetadata,
-            get_standard_logging_metadata,
+            StandardLoggingPayloadSetup,
         )
 
         verbose_logger.debug(
@@ -1000,8 +1004,10 @@ class PrometheusLogger(CustomLogger):
             kwargs,
         )
         _metadata = kwargs.get("metadata", {})
-        standard_metadata: StandardLoggingMetadata = get_standard_logging_metadata(
-            metadata=_metadata
+        standard_metadata: StandardLoggingMetadata = (
+            StandardLoggingPayloadSetup.get_standard_logging_metadata(
+                metadata=_metadata
+            )
         )
         _new_model = kwargs.get("model")
         self.litellm_deployment_successful_fallbacks.labels(
@@ -1023,7 +1029,7 @@ class PrometheusLogger(CustomLogger):
         """
         from litellm.litellm_core_utils.litellm_logging import (
             StandardLoggingMetadata,
-            get_standard_logging_metadata,
+            StandardLoggingPayloadSetup,
         )
 
         verbose_logger.debug(
@@ -1033,8 +1039,10 @@ class PrometheusLogger(CustomLogger):
         )
         _new_model = kwargs.get("model")
         _metadata = kwargs.get("metadata", {})
-        standard_metadata: StandardLoggingMetadata = get_standard_logging_metadata(
-            metadata=_metadata
+        standard_metadata: StandardLoggingMetadata = (
+            StandardLoggingPayloadSetup.get_standard_logging_metadata(
+                metadata=_metadata
+            )
         )
         self.litellm_deployment_failed_fallbacks.labels(
             requested_model=original_model_group,
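The stricter guard above now rejects non-dict payloads in addition to missing ones. A hedged illustration with a hypothetical kwargs value (not taken from this commit):

    kwargs = {"standard_logging_object": "not-a-dict"}  # hypothetical bad input
    payload = kwargs.get("standard_logging_object")
    # the old check passed here (payload is not None); the new check raises
    if payload is None or not isinstance(payload, dict):
        raise ValueError(f"standard_logging_object is required, got={payload}")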
@@ -12,7 +12,7 @@ import time
 import traceback
 import uuid
 from datetime import datetime as dt_object
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 from pydantic import BaseModel
 
@@ -51,6 +51,7 @@ from litellm.types.utils import (
     StandardPassThroughResponseObject,
     TextCompletionResponse,
     TranscriptionResponse,
+    Usage,
 )
 from litellm.utils import (
     _get_base_model_from_metadata,
@@ -2454,7 +2455,183 @@ def is_valid_sha256_hash(value: str) -> bool:
     return bool(re.fullmatch(r"[a-fA-F0-9]{64}", value))
 
 
-def get_standard_logging_object_payload(  # noqa: PLR0915
+class StandardLoggingPayloadSetup:
+    @staticmethod
+    def cleanup_timestamps(
+        start_time: Union[dt_object, float],
+        end_time: Union[dt_object, float],
+        completion_start_time: Union[dt_object, float],
+    ) -> Tuple[float, float, float]:
+        """
+        Convert datetime objects to floats
+        """
+
+        if isinstance(start_time, datetime.datetime):
+            start_time_float = start_time.timestamp()
+        elif isinstance(start_time, float):
+            start_time_float = start_time
+        else:
+            raise ValueError(
+                f"start_time is required, got={start_time} of type {type(start_time)}"
+            )
+
+        if isinstance(end_time, datetime.datetime):
+            end_time_float = end_time.timestamp()
+        elif isinstance(end_time, float):
+            end_time_float = end_time
+        else:
+            raise ValueError(
+                f"end_time is required, got={end_time} of type {type(end_time)}"
+            )
+
+        if isinstance(completion_start_time, datetime.datetime):
+            completion_start_time_float = completion_start_time.timestamp()
+        elif isinstance(completion_start_time, float):
+            completion_start_time_float = completion_start_time
+        else:
+            completion_start_time_float = end_time_float
+
+        return start_time_float, end_time_float, completion_start_time_float
+
+    @staticmethod
+    def get_standard_logging_metadata(
+        metadata: Optional[Dict[str, Any]]
+    ) -> StandardLoggingMetadata:
+        """
+        Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
+
+        Args:
+            metadata (Optional[Dict[str, Any]]): The original metadata dictionary.
+
+        Returns:
+            StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata.
+
+        Note:
+            - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
+            - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
+        """
+        # Initialize with default values
+        clean_metadata = StandardLoggingMetadata(
+            user_api_key_hash=None,
+            user_api_key_alias=None,
+            user_api_key_team_id=None,
+            user_api_key_user_id=None,
+            user_api_key_team_alias=None,
+            spend_logs_metadata=None,
+            requester_ip_address=None,
+            requester_metadata=None,
+        )
+        if isinstance(metadata, dict):
+            # Filter the metadata dictionary to include only the specified keys
+            clean_metadata = StandardLoggingMetadata(
+                **{  # type: ignore
+                    key: metadata[key]
+                    for key in StandardLoggingMetadata.__annotations__.keys()
+                    if key in metadata
+                }
+            )
+
+            if metadata.get("user_api_key") is not None:
+                if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
+                    clean_metadata["user_api_key_hash"] = metadata.get(
+                        "user_api_key"
+                    )  # this is the hash
+        return clean_metadata
+
+    @staticmethod
+    def get_usage_from_response_obj(response_obj: Optional[dict]) -> Usage:
+        ## BASE CASE ##
+        if response_obj is None:
+            return Usage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+
+        usage = response_obj.get("usage", None) or {}
+        if usage is None or (
+            not isinstance(usage, dict) and not isinstance(usage, Usage)
+        ):
+            return Usage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+        elif isinstance(usage, Usage):
+            return usage
+        elif isinstance(usage, dict):
+            return Usage(**usage)
+
+        raise ValueError(f"usage is required, got={usage} of type {type(usage)}")
+
+    @staticmethod
+    def get_model_cost_information(
+        base_model: Optional[str],
+        custom_pricing: Optional[bool],
+        custom_llm_provider: Optional[str],
+        init_response_obj: Union[Any, BaseModel, dict],
+    ) -> StandardLoggingModelInformation:
+
+        model_cost_name = _select_model_name_for_cost_calc(
+            model=None,
+            completion_response=init_response_obj,  # type: ignore
+            base_model=base_model,
+            custom_pricing=custom_pricing,
+        )
+        if model_cost_name is None:
+            model_cost_information = StandardLoggingModelInformation(
+                model_map_key="", model_map_value=None
+            )
+        else:
+            try:
+                _model_cost_information = litellm.get_model_info(
+                    model=model_cost_name, custom_llm_provider=custom_llm_provider
+                )
+                model_cost_information = StandardLoggingModelInformation(
+                    model_map_key=model_cost_name,
+                    model_map_value=_model_cost_information,
+                )
+            except Exception:
+                verbose_logger.debug(  # keep in debug otherwise it will trigger on every call
+                    "Model={} is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload".format(
+                        model_cost_name
+                    )
+                )
+                model_cost_information = StandardLoggingModelInformation(
+                    model_map_key=model_cost_name, model_map_value=None
+                )
+        return model_cost_information
+
+    @staticmethod
+    def get_final_response_obj(
+        response_obj: dict, init_response_obj: Union[Any, BaseModel, dict], kwargs: dict
+    ) -> Optional[Union[dict, str, list]]:
+        """
+        Get final response object after redacting the message input/output from logging
+        """
+        if response_obj is not None:
+            final_response_obj: Optional[Union[dict, str, list]] = response_obj
+        elif isinstance(init_response_obj, list) or isinstance(init_response_obj, str):
+            final_response_obj = init_response_obj
+        else:
+            final_response_obj = None
+
+        modified_final_response_obj = redact_message_input_output_from_logging(
+            model_call_details=kwargs,
+            result=final_response_obj,
+        )
+
+        if modified_final_response_obj is not None and isinstance(
+            modified_final_response_obj, BaseModel
+        ):
+            final_response_obj = modified_final_response_obj.model_dump()
+        else:
+            final_response_obj = modified_final_response_obj
+
+        return final_response_obj
+
+
+def get_standard_logging_object_payload(
     kwargs: Optional[dict],
     init_response_obj: Union[Any, BaseModel, dict],
     start_time: dt_object,
@@ -2502,9 +2679,9 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
     completion_start_time = kwargs.get("completion_start_time", end_time)
     call_type = kwargs.get("call_type")
     cache_hit = kwargs.get("cache_hit", False)
-    usage = response_obj.get("usage", None) or {}
-    if type(usage) is litellm.Usage:
-        usage = dict(usage)
+    usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(
+        response_obj=response_obj
+    )
     id = response_obj.get("id", kwargs.get("litellm_call_id"))
 
     _model_id = metadata.get("model_info", {}).get("id", "")
@@ -2517,20 +2694,13 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
     )
 
     # cleanup timestamps
-    if isinstance(start_time, datetime.datetime):
-        start_time_float = start_time.timestamp()
-    elif isinstance(start_time, float):
-        start_time_float = start_time
-    if isinstance(end_time, datetime.datetime):
-        end_time_float = end_time.timestamp()
-    elif isinstance(end_time, float):
-        end_time_float = end_time
-    if isinstance(completion_start_time, datetime.datetime):
-        completion_start_time_float = completion_start_time.timestamp()
-    elif isinstance(completion_start_time, float):
-        completion_start_time_float = completion_start_time
-    else:
-        completion_start_time_float = end_time_float
+    start_time_float, end_time_float, completion_start_time_float = (
+        StandardLoggingPayloadSetup.cleanup_timestamps(
+            start_time=start_time,
+            end_time=end_time,
+            completion_start_time=completion_start_time,
+        )
+    )
     # clean up litellm hidden params
     clean_hidden_params = StandardLoggingHiddenParams(
         model_id=None,
@@ -2548,7 +2718,9 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
         }
     )
     # clean up litellm metadata
-    clean_metadata = get_standard_logging_metadata(metadata=metadata)
+    clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
+        metadata=metadata
+    )
 
     if litellm.cache is not None:
         cache_key = litellm.cache.get_cache_key(**kwargs)
@@ -2570,58 +2742,21 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
     ## Get model cost information ##
     base_model = _get_base_model_from_metadata(model_call_details=kwargs)
     custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
-    model_cost_name = _select_model_name_for_cost_calc(
-        model=None,
-        completion_response=init_response_obj,  # type: ignore
+    model_cost_information = StandardLoggingPayloadSetup.get_model_cost_information(
         base_model=base_model,
         custom_pricing=custom_pricing,
+        custom_llm_provider=kwargs.get("custom_llm_provider"),
+        init_response_obj=init_response_obj,
     )
-    if model_cost_name is None:
-        model_cost_information = StandardLoggingModelInformation(
-            model_map_key="", model_map_value=None
-        )
-    else:
-        custom_llm_provider = kwargs.get("custom_llm_provider", None)
-
-        try:
-            _model_cost_information = litellm.get_model_info(
-                model=model_cost_name, custom_llm_provider=custom_llm_provider
-            )
-            model_cost_information = StandardLoggingModelInformation(
-                model_map_key=model_cost_name,
-                model_map_value=_model_cost_information,
-            )
-        except Exception:
-            verbose_logger.debug(  # keep in debug otherwise it will trigger on every call
-                "Model={} is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload".format(
-                    model_cost_name
-                )
-            )
-            model_cost_information = StandardLoggingModelInformation(
-                model_map_key=model_cost_name, model_map_value=None
-            )
 
     response_cost: float = kwargs.get("response_cost", 0) or 0.0
 
-    if response_obj is not None:
-        final_response_obj: Optional[Union[dict, str, list]] = response_obj
-    elif isinstance(init_response_obj, list) or isinstance(init_response_obj, str):
-        final_response_obj = init_response_obj
-    else:
-        final_response_obj = None
-
-    modified_final_response_obj = redact_message_input_output_from_logging(
-        model_call_details=kwargs,
-        result=final_response_obj,
+    ## get final response object ##
+    final_response_obj = StandardLoggingPayloadSetup.get_final_response_obj(
+        response_obj=response_obj,
+        init_response_obj=init_response_obj,
+        kwargs=kwargs,
     )
 
-    if modified_final_response_obj is not None and isinstance(
-        modified_final_response_obj, BaseModel
-    ):
-        final_response_obj = modified_final_response_obj.model_dump()
-    else:
-        final_response_obj = modified_final_response_obj
-
     payload: StandardLoggingPayload = StandardLoggingPayload(
         id=str(id),
         call_type=call_type or "",
@@ -2635,9 +2770,9 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
         metadata=clean_metadata,
         cache_key=cache_key,
         response_cost=response_cost,
-        total_tokens=usage.get("total_tokens", 0),
-        prompt_tokens=usage.get("prompt_tokens", 0),
-        completion_tokens=usage.get("completion_tokens", 0),
+        total_tokens=usage.total_tokens,
+        prompt_tokens=usage.prompt_tokens,
+        completion_tokens=usage.completion_tokens,
         request_tags=request_tags,
         end_user=end_user_id or "",
         api_base=litellm_params.get("api_base", ""),
@@ -2663,51 +2798,6 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
         return None
 
 
-def get_standard_logging_metadata(
-    metadata: Optional[Dict[str, Any]]
-) -> StandardLoggingMetadata:
-    """
-    Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
-
-    Args:
-        metadata (Optional[Dict[str, Any]]): The original metadata dictionary.
-
-    Returns:
-        StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata.
-
-    Note:
-        - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
-        - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
-    """
-    # Initialize with default values
-    clean_metadata = StandardLoggingMetadata(
-        user_api_key_hash=None,
-        user_api_key_alias=None,
-        user_api_key_team_id=None,
-        user_api_key_user_id=None,
-        user_api_key_team_alias=None,
-        spend_logs_metadata=None,
-        requester_ip_address=None,
-        requester_metadata=None,
-    )
-    if isinstance(metadata, dict):
-        # Filter the metadata dictionary to include only the specified keys
-        clean_metadata = StandardLoggingMetadata(
-            **{  # type: ignore
-                key: metadata[key]
-                for key in StandardLoggingMetadata.__annotations__.keys()
-                if key in metadata
-            }
-        )
-
-        if metadata.get("user_api_key") is not None:
-            if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
-                clean_metadata["user_api_key_hash"] = metadata.get(
-                    "user_api_key"
-                )  # this is the hash
-    return clean_metadata
-
-
 def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
     if litellm_params is None:
         litellm_params = {}
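A small sketch of the new cleanup_timestamps contract, assuming litellm is importable; the values are illustrative:

    import datetime

    from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup

    start = datetime.datetime(2024, 10, 23, 12, 0, 0)
    end = start.timestamp() + 2.5  # floats pass through unchanged

    # A completion_start_time that is neither datetime nor float falls back to
    # end_time, matching the old inline logic this commit replaces.
    s, e, c = StandardLoggingPayloadSetup.cleanup_timestamps(
        start_time=start, end_time=end, completion_start_time=None  # type: ignore
    )
    assert s == start.timestamp() and e == end and c == e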
@@ -0,0 +1,67 @@
+"""
+Unit tests for StandardLoggingPayloadSetup
+"""
+
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock
+
+from pydantic.main import Model
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system-path
+
+import pytest
+import litellm
+from litellm.types.utils import Usage
+from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
+
+
+@pytest.mark.parametrize(
+    "response_obj,expected_values",
+    [
+        # Test None input
+        (None, (0, 0, 0)),
+        # Test empty dict
+        ({}, (0, 0, 0)),
+        # Test valid usage dict
+        (
+            {
+                "usage": {
+                    "prompt_tokens": 10,
+                    "completion_tokens": 20,
+                    "total_tokens": 30,
+                }
+            },
+            (10, 20, 30),
+        ),
+        # Test with litellm.Usage object
+        (
+            {"usage": Usage(prompt_tokens=15, completion_tokens=25, total_tokens=40)},
+            (15, 25, 40),
+        ),
+        # Test invalid usage type
+        ({"usage": "invalid"}, (0, 0, 0)),
+        # Test None usage
+        ({"usage": None}, (0, 0, 0)),
+    ],
+)
+def test_get_usage(response_obj, expected_values):
+    """
+    Make sure values returned from get_usage are always integers
+    """
+
+    usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(response_obj)
+
+    # Check types
+    assert isinstance(usage.prompt_tokens, int)
+    assert isinstance(usage.completion_tokens, int)
+    assert isinstance(usage.total_tokens, int)
+
+    # Check values
+    assert usage.prompt_tokens == expected_values[0]
+    assert usage.completion_tokens == expected_values[1]
+    assert usage.total_tokens == expected_values[2]
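The new test should run standalone under pytest; the file path here is illustrative, since this view does not display the new file's name:

    pytest test_standard_logging_payload.py -k test_get_usage -v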