forked from phoenix/litellm-mirror

feat(litellm_logging.py): refactor standard_logging_payload function to be <50 LOC (#6388)

* feat(litellm_logging.py): refactor standard_logging_payload function to be <50 LOC; fixes an issue where usage information was not following typed values
* fix(litellm_logging.py): fix completion start time handling

Parent: d59f8f952d
Commit: c04c4a82f1

3 changed files with 286 additions and 121 deletions
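Before the diff, a quick orientation: the headline fix is that token usage is now normalized through litellm's typed Usage object instead of being read out of a raw dict with .get(). A minimal sketch of that behavior, using only names that appear in the diff below (the sample response dicts are made-up inputs):

    from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup

    # A provider response may carry usage as a plain dict, a Usage object,
    # None, or junk; after this commit all of them normalize to a Usage.
    for response_obj in (
        None,  # no response at all -> zeroed Usage
        {"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}},
        {"usage": "invalid"},  # junk value -> zeroed Usage
    ):
        usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(response_obj)
        assert isinstance(usage.total_tokens, int)  # typed attribute access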
File 1: Prometheus logger integration (class PrometheusLogger)

@@ -346,8 +346,12 @@ class PrometheusLogger(CustomLogger):
         standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
             "standard_logging_object"
         )
-        if standard_logging_payload is None:
-            raise ValueError("standard_logging_object is required")
+        if standard_logging_payload is None or not isinstance(
+            standard_logging_payload, dict
+        ):
+            raise ValueError(
+                f"standard_logging_object is required, got={standard_logging_payload}"
+            )
         model = kwargs.get("model", "")
         litellm_params = kwargs.get("litellm_params", {}) or {}
         _metadata = litellm_params.get("metadata", {})
@@ -991,7 +995,7 @@ class PrometheusLogger(CustomLogger):
         """
         from litellm.litellm_core_utils.litellm_logging import (
             StandardLoggingMetadata,
-            get_standard_logging_metadata,
+            StandardLoggingPayloadSetup,
         )

         verbose_logger.debug(
@@ -1000,8 +1004,10 @@ class PrometheusLogger(CustomLogger):
             kwargs,
         )
         _metadata = kwargs.get("metadata", {})
-        standard_metadata: StandardLoggingMetadata = get_standard_logging_metadata(
-            metadata=_metadata
+        standard_metadata: StandardLoggingMetadata = (
+            StandardLoggingPayloadSetup.get_standard_logging_metadata(
+                metadata=_metadata
+            )
         )
         _new_model = kwargs.get("model")
         self.litellm_deployment_successful_fallbacks.labels(
@@ -1023,7 +1029,7 @@ class PrometheusLogger(CustomLogger):
         """
         from litellm.litellm_core_utils.litellm_logging import (
             StandardLoggingMetadata,
-            get_standard_logging_metadata,
+            StandardLoggingPayloadSetup,
        )

        verbose_logger.debug(
@@ -1033,8 +1039,10 @@ class PrometheusLogger(CustomLogger):
         )
         _new_model = kwargs.get("model")
         _metadata = kwargs.get("metadata", {})
-        standard_metadata: StandardLoggingMetadata = get_standard_logging_metadata(
-            metadata=_metadata
+        standard_metadata: StandardLoggingMetadata = (
+            StandardLoggingPayloadSetup.get_standard_logging_metadata(
+                metadata=_metadata
+            )
         )
         self.litellm_deployment_failed_fallbacks.labels(
             requested_model=original_model_group,
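Both fallback-metric paths above now route metadata through the new setup class instead of the module-level helper that this commit deletes. A minimal sketch of the call (the metadata contents are assumed for illustration):

    from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup

    _metadata = {"user_api_key_alias": "my-alias"}  # hypothetical request metadata

    # Same cleaning/filtering behavior as before, now namespaced on the class;
    # keys not declared on StandardLoggingMetadata are dropped, known keys kept.
    standard_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
        metadata=_metadata
    )
    assert standard_metadata["user_api_key_alias"] == "my-alias"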
File 2: litellm/litellm_core_utils/litellm_logging.py

@@ -12,7 +12,7 @@ import time
 import traceback
 import uuid
 from datetime import datetime as dt_object
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 from pydantic import BaseModel

@@ -51,6 +51,7 @@ from litellm.types.utils import (
     StandardPassThroughResponseObject,
     TextCompletionResponse,
     TranscriptionResponse,
+    Usage,
 )
 from litellm.utils import (
     _get_base_model_from_metadata,
@@ -2454,7 +2455,183 @@ def is_valid_sha256_hash(value: str) -> bool:
     return bool(re.fullmatch(r"[a-fA-F0-9]{64}", value))


-def get_standard_logging_object_payload(  # noqa: PLR0915
+class StandardLoggingPayloadSetup:
+    @staticmethod
+    def cleanup_timestamps(
+        start_time: Union[dt_object, float],
+        end_time: Union[dt_object, float],
+        completion_start_time: Union[dt_object, float],
+    ) -> Tuple[float, float, float]:
+        """
+        Convert datetime objects to floats
+        """
+
+        if isinstance(start_time, datetime.datetime):
+            start_time_float = start_time.timestamp()
+        elif isinstance(start_time, float):
+            start_time_float = start_time
+        else:
+            raise ValueError(
+                f"start_time is required, got={start_time} of type {type(start_time)}"
+            )
+
+        if isinstance(end_time, datetime.datetime):
+            end_time_float = end_time.timestamp()
+        elif isinstance(end_time, float):
+            end_time_float = end_time
+        else:
+            raise ValueError(
+                f"end_time is required, got={end_time} of type {type(end_time)}"
+            )
+
+        if isinstance(completion_start_time, datetime.datetime):
+            completion_start_time_float = completion_start_time.timestamp()
+        elif isinstance(completion_start_time, float):
+            completion_start_time_float = completion_start_time
+        else:
+            completion_start_time_float = end_time_float
+
+        return start_time_float, end_time_float, completion_start_time_float
+
+    @staticmethod
+    def get_standard_logging_metadata(
+        metadata: Optional[Dict[str, Any]]
+    ) -> StandardLoggingMetadata:
+        """
+        Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
+
+        Args:
+            metadata (Optional[Dict[str, Any]]): The original metadata dictionary.
+
+        Returns:
+            StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata.
+
+        Note:
+            - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
+            - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
+        """
+        # Initialize with default values
+        clean_metadata = StandardLoggingMetadata(
+            user_api_key_hash=None,
+            user_api_key_alias=None,
+            user_api_key_team_id=None,
+            user_api_key_user_id=None,
+            user_api_key_team_alias=None,
+            spend_logs_metadata=None,
+            requester_ip_address=None,
+            requester_metadata=None,
+        )
+        if isinstance(metadata, dict):
+            # Filter the metadata dictionary to include only the specified keys
+            clean_metadata = StandardLoggingMetadata(
+                **{  # type: ignore
+                    key: metadata[key]
+                    for key in StandardLoggingMetadata.__annotations__.keys()
+                    if key in metadata
+                }
+            )
+
+            if metadata.get("user_api_key") is not None:
+                if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
+                    clean_metadata["user_api_key_hash"] = metadata.get(
+                        "user_api_key"
+                    )  # this is the hash
+        return clean_metadata
+
+    @staticmethod
+    def get_usage_from_response_obj(response_obj: Optional[dict]) -> Usage:
+        ## BASE CASE ##
+        if response_obj is None:
+            return Usage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+
+        usage = response_obj.get("usage", None) or {}
+        if usage is None or (
+            not isinstance(usage, dict) and not isinstance(usage, Usage)
+        ):
+            return Usage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+        elif isinstance(usage, Usage):
+            return usage
+        elif isinstance(usage, dict):
+            return Usage(**usage)
+
+        raise ValueError(f"usage is required, got={usage} of type {type(usage)}")
+
+    @staticmethod
+    def get_model_cost_information(
+        base_model: Optional[str],
+        custom_pricing: Optional[bool],
+        custom_llm_provider: Optional[str],
+        init_response_obj: Union[Any, BaseModel, dict],
+    ) -> StandardLoggingModelInformation:
+
+        model_cost_name = _select_model_name_for_cost_calc(
+            model=None,
+            completion_response=init_response_obj,  # type: ignore
+            base_model=base_model,
+            custom_pricing=custom_pricing,
+        )
+        if model_cost_name is None:
+            model_cost_information = StandardLoggingModelInformation(
+                model_map_key="", model_map_value=None
+            )
+        else:
+            try:
+                _model_cost_information = litellm.get_model_info(
+                    model=model_cost_name, custom_llm_provider=custom_llm_provider
+                )
+                model_cost_information = StandardLoggingModelInformation(
+                    model_map_key=model_cost_name,
+                    model_map_value=_model_cost_information,
+                )
+            except Exception:
+                verbose_logger.debug(  # keep in debug otherwise it will trigger on every call
+                    "Model={} is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload".format(
+                        model_cost_name
+                    )
+                )
+                model_cost_information = StandardLoggingModelInformation(
+                    model_map_key=model_cost_name, model_map_value=None
+                )
+        return model_cost_information
+
+    @staticmethod
+    def get_final_response_obj(
+        response_obj: dict, init_response_obj: Union[Any, BaseModel, dict], kwargs: dict
+    ) -> Optional[Union[dict, str, list]]:
+        """
+        Get final response object after redacting the message input/output from logging
+        """
+        if response_obj is not None:
+            final_response_obj: Optional[Union[dict, str, list]] = response_obj
+        elif isinstance(init_response_obj, list) or isinstance(init_response_obj, str):
+            final_response_obj = init_response_obj
+        else:
+            final_response_obj = None
+
+        modified_final_response_obj = redact_message_input_output_from_logging(
+            model_call_details=kwargs,
+            result=final_response_obj,
+        )
+
+        if modified_final_response_obj is not None and isinstance(
+            modified_final_response_obj, BaseModel
+        ):
+            final_response_obj = modified_final_response_obj.model_dump()
+        else:
+            final_response_obj = modified_final_response_obj
+
+        return final_response_obj
+
+
+def get_standard_logging_object_payload(
     kwargs: Optional[dict],
     init_response_obj: Union[Any, BaseModel, dict],
     start_time: dt_object,
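The "completion start time handling" part of the commit message lives in cleanup_timestamps above: datetimes are converted to epoch floats, an unusable completion_start_time falls back to the end time, and a bad start/end time now raises a descriptive ValueError instead of leaving the *_float locals unbound as the old inline code did. A small sketch (the timestamps are arbitrary):

    import datetime

    from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup

    start = datetime.datetime(2024, 10, 22, 12, 0, 0)
    end = datetime.datetime(2024, 10, 22, 12, 0, 5)

    # None is neither a datetime nor a float, so completion_start_time
    # falls back to the end time rather than raising.
    s, e, c = StandardLoggingPayloadSetup.cleanup_timestamps(
        start_time=start, end_time=end, completion_start_time=None  # type: ignore
    )
    assert e - s == 5.0 and c == e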
@@ -2502,9 +2679,9 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
     completion_start_time = kwargs.get("completion_start_time", end_time)
     call_type = kwargs.get("call_type")
     cache_hit = kwargs.get("cache_hit", False)
-    usage = response_obj.get("usage", None) or {}
-    if type(usage) is litellm.Usage:
-        usage = dict(usage)
+    usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(
+        response_obj=response_obj
+    )
     id = response_obj.get("id", kwargs.get("litellm_call_id"))

     _model_id = metadata.get("model_info", {}).get("id", "")
@@ -2517,20 +2694,13 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
     )

     # cleanup timestamps
-    if isinstance(start_time, datetime.datetime):
-        start_time_float = start_time.timestamp()
-    elif isinstance(start_time, float):
-        start_time_float = start_time
-    if isinstance(end_time, datetime.datetime):
-        end_time_float = end_time.timestamp()
-    elif isinstance(end_time, float):
-        end_time_float = end_time
-    if isinstance(completion_start_time, datetime.datetime):
-        completion_start_time_float = completion_start_time.timestamp()
-    elif isinstance(completion_start_time, float):
-        completion_start_time_float = completion_start_time
-    else:
-        completion_start_time_float = end_time_float
+    start_time_float, end_time_float, completion_start_time_float = (
+        StandardLoggingPayloadSetup.cleanup_timestamps(
+            start_time=start_time,
+            end_time=end_time,
+            completion_start_time=completion_start_time,
+        )
+    )
     # clean up litellm hidden params
     clean_hidden_params = StandardLoggingHiddenParams(
         model_id=None,
@@ -2548,7 +2718,9 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
         }
     )
     # clean up litellm metadata
-    clean_metadata = get_standard_logging_metadata(metadata=metadata)
+    clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
+        metadata=metadata
+    )

     if litellm.cache is not None:
         cache_key = litellm.cache.get_cache_key(**kwargs)
@@ -2570,58 +2742,21 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
     ## Get model cost information ##
     base_model = _get_base_model_from_metadata(model_call_details=kwargs)
     custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
-    model_cost_name = _select_model_name_for_cost_calc(
-        model=None,
-        completion_response=init_response_obj,  # type: ignore
+    model_cost_information = StandardLoggingPayloadSetup.get_model_cost_information(
         base_model=base_model,
         custom_pricing=custom_pricing,
+        custom_llm_provider=kwargs.get("custom_llm_provider"),
+        init_response_obj=init_response_obj,
     )
-    if model_cost_name is None:
-        model_cost_information = StandardLoggingModelInformation(
-            model_map_key="", model_map_value=None
-        )
-    else:
-        custom_llm_provider = kwargs.get("custom_llm_provider", None)
-
-        try:
-            _model_cost_information = litellm.get_model_info(
-                model=model_cost_name, custom_llm_provider=custom_llm_provider
-            )
-            model_cost_information = StandardLoggingModelInformation(
-                model_map_key=model_cost_name,
-                model_map_value=_model_cost_information,
-            )
-        except Exception:
-            verbose_logger.debug(  # keep in debug otherwise it will trigger on every call
-                "Model={} is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload".format(
-                    model_cost_name
-                )
-            )
-            model_cost_information = StandardLoggingModelInformation(
-                model_map_key=model_cost_name, model_map_value=None
-            )
-
     response_cost: float = kwargs.get("response_cost", 0) or 0.0

-    if response_obj is not None:
-        final_response_obj: Optional[Union[dict, str, list]] = response_obj
-    elif isinstance(init_response_obj, list) or isinstance(init_response_obj, str):
-        final_response_obj = init_response_obj
-    else:
-        final_response_obj = None
-
-    modified_final_response_obj = redact_message_input_output_from_logging(
-        model_call_details=kwargs,
-        result=final_response_obj,
+    ## get final response object ##
+    final_response_obj = StandardLoggingPayloadSetup.get_final_response_obj(
+        response_obj=response_obj,
+        init_response_obj=init_response_obj,
+        kwargs=kwargs,
     )

-    if modified_final_response_obj is not None and isinstance(
-        modified_final_response_obj, BaseModel
-    ):
-        final_response_obj = modified_final_response_obj.model_dump()
-    else:
-        final_response_obj = modified_final_response_obj
-
     payload: StandardLoggingPayload = StandardLoggingPayload(
         id=str(id),
         call_type=call_type or "",
@@ -2635,9 +2770,9 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
         metadata=clean_metadata,
         cache_key=cache_key,
         response_cost=response_cost,
-        total_tokens=usage.get("total_tokens", 0),
-        prompt_tokens=usage.get("prompt_tokens", 0),
-        completion_tokens=usage.get("completion_tokens", 0),
+        total_tokens=usage.total_tokens,
+        prompt_tokens=usage.prompt_tokens,
+        completion_tokens=usage.completion_tokens,
         request_tags=request_tags,
         end_user=end_user_id or "",
         api_base=litellm_params.get("api_base", ""),
@@ -2663,51 +2798,6 @@ def get_standard_logging_object_payload(  # noqa: PLR0915
     return None


-def get_standard_logging_metadata(
-    metadata: Optional[Dict[str, Any]]
-) -> StandardLoggingMetadata:
-    """
-    Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
-
-    Args:
-        metadata (Optional[Dict[str, Any]]): The original metadata dictionary.
-
-    Returns:
-        StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata.
-
-    Note:
-        - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
-        - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
-    """
-    # Initialize with default values
-    clean_metadata = StandardLoggingMetadata(
-        user_api_key_hash=None,
-        user_api_key_alias=None,
-        user_api_key_team_id=None,
-        user_api_key_user_id=None,
-        user_api_key_team_alias=None,
-        spend_logs_metadata=None,
-        requester_ip_address=None,
-        requester_metadata=None,
-    )
-    if isinstance(metadata, dict):
-        # Filter the metadata dictionary to include only the specified keys
-        clean_metadata = StandardLoggingMetadata(
-            **{  # type: ignore
-                key: metadata[key]
-                for key in StandardLoggingMetadata.__annotations__.keys()
-                if key in metadata
-            }
-        )
-
-        if metadata.get("user_api_key") is not None:
-            if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
-                clean_metadata["user_api_key_hash"] = metadata.get(
-                    "user_api_key"
-                )  # this is the hash
-    return clean_metadata
-
-
 def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
     if litellm_params is None:
         litellm_params = {}
File 3: new unit tests for StandardLoggingPayloadSetup

@@ -0,0 +1,67 @@
+"""
+Unit tests for StandardLoggingPayloadSetup
+"""
+
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock
+
+from pydantic.main import Model
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system-path
+
+import pytest
+import litellm
+from litellm.types.utils import Usage
+from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
+
+
+@pytest.mark.parametrize(
+    "response_obj,expected_values",
+    [
+        # Test None input
+        (None, (0, 0, 0)),
+        # Test empty dict
+        ({}, (0, 0, 0)),
+        # Test valid usage dict
+        (
+            {
+                "usage": {
+                    "prompt_tokens": 10,
+                    "completion_tokens": 20,
+                    "total_tokens": 30,
+                }
+            },
+            (10, 20, 30),
+        ),
+        # Test with litellm.Usage object
+        (
+            {"usage": Usage(prompt_tokens=15, completion_tokens=25, total_tokens=40)},
+            (15, 25, 40),
+        ),
+        # Test invalid usage type
+        ({"usage": "invalid"}, (0, 0, 0)),
+        # Test None usage
+        ({"usage": None}, (0, 0, 0)),
+    ],
+)
+def test_get_usage(response_obj, expected_values):
+    """
+    Make sure values returned from get_usage are always integers
+    """
+
+    usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(response_obj)
+
+    # Check types
+    assert isinstance(usage.prompt_tokens, int)
+    assert isinstance(usage.completion_tokens, int)
+    assert isinstance(usage.total_tokens, int)
+
+    # Check values
+    assert usage.prompt_tokens == expected_values[0]
+    assert usage.completion_tokens == expected_values[1]
+    assert usage.total_tokens == expected_values[2]
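To exercise the new test locally, something like the following should work; the test file's actual name is not shown on this page, so the one below is a placeholder:

    import pytest

    # Hypothetical file name; substitute the real test module from the repo.
    raise SystemExit(
        pytest.main(["-v", "-k", "test_get_usage", "test_standard_logging_payload.py"])
    )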