(QOL improvement) add unit testing for all static_methods in litellm_logging.py (#6640)

* add unit testing for standard logging payload

* unit testing for static methods in litellm_logging

* add code coverage check for litellm_logging

* litellm_logging_code_coverage

* test_get_final_response_obj

* fix validate_redacted_message_span_attributes

* test validate_redacted_message_span_attributes
This commit is contained in:
Ishaan Jaff 2024-11-07 16:26:53 -08:00 committed by GitHub
parent 6e4a9bb3b7
commit ae385cfcdc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 334 additions and 9 deletions

View file

@ -722,6 +722,7 @@ jobs:
- run: python ./tests/documentation_tests/test_general_setting_keys.py
- run: python ./tests/code_coverage_tests/router_code_coverage.py
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
- run: python ./tests/documentation_tests/test_env_keys.py
- run: helm lint ./deploy/charts/litellm-helm

View file

@ -2474,6 +2474,14 @@ class StandardLoggingPayloadSetup:
) -> Tuple[float, float, float]:
"""
Convert datetime objects to floats
Args:
start_time: Union[dt_object, float]
end_time: Union[dt_object, float]
completion_start_time: Union[dt_object, float]
Returns:
Tuple[float, float, float]: A tuple containing the start time, end time, and completion start time as floats.
"""
if isinstance(start_time, datetime.datetime):
@ -2534,13 +2542,10 @@ class StandardLoggingPayloadSetup:
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
clean_metadata = StandardLoggingMetadata(
**{ # type: ignore
key: metadata[key]
for key in StandardLoggingMetadata.__annotations__.keys()
if key in metadata
}
)
supported_keys = StandardLoggingMetadata.__annotations__.keys()
for key in supported_keys:
if key in metadata:
clean_metadata[key] = metadata[key] # type: ignore
if metadata.get("user_api_key") is not None:
if is_valid_sha256_hash(str(metadata.get("user_api_key"))):

View file

@ -0,0 +1,95 @@
import ast
import os
from typing import List
def get_function_names_from_file(file_path: str) -> List[str]:
"""
Extracts all static method names from litellm_logging.py
"""
with open(file_path, "r") as file:
tree = ast.parse(file.read())
function_names = []
for node in tree.body:
if isinstance(node, ast.ClassDef):
# Functions inside classes
for class_node in node.body:
if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
# Check if the function has @staticmethod decorator
for decorator in class_node.decorator_list:
if (
isinstance(decorator, ast.Name)
and decorator.id == "staticmethod"
):
function_names.append(class_node.name)
return function_names
def get_all_functions_called_in_tests(base_dir: str) -> set:
"""
Returns a set of function names that are called in test functions
inside test files containing the word 'logging'.
"""
called_functions = set()
for root, _, files in os.walk(base_dir):
for file in files:
if file.endswith(".py") and "logging" in file.lower():
file_path = os.path.join(root, file)
with open(file_path, "r") as f:
try:
tree = ast.parse(f.read())
except SyntaxError:
print(f"Warning: Syntax error in file {file_path}")
continue
for node in ast.walk(tree):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
called_functions.add(node.func.id)
elif isinstance(node.func, ast.Attribute):
called_functions.add(node.func.attr)
return called_functions
# Functions that can be ignored in test coverage
ignored_function_names = [
"__init__",
# Add other functions to ignore here
]
def main():
logging_file = "./litellm/litellm_core_utils/litellm_logging.py"
tests_dir = "./tests/"
# LOCAL TESTING
# logging_file = "../../litellm/litellm_core_utils/litellm_logging.py"
# tests_dir = "../../tests/"
logging_functions = get_function_names_from_file(logging_file)
print("logging_functions:", logging_functions)
called_functions_in_tests = get_all_functions_called_in_tests(tests_dir)
untested_functions = [
fn
for fn in logging_functions
if fn not in called_functions_in_tests and fn not in ignored_function_names
]
if untested_functions:
untested_perc = len(untested_functions) / len(logging_functions)
print(f"untested_percentage: {untested_perc * 100:.2f}%")
raise Exception(
f"{untested_perc * 100:.2f}% of functions in litellm_logging.py are not tested: {untested_functions}"
)
else:
print("All functions in litellm_logging.py are covered by tests.")
if __name__ == "__main__":
main()

View file

@ -260,6 +260,15 @@ def validate_redacted_message_span_attributes(span):
"llm.usage.total_tokens",
"gen_ai.usage.completion_tokens",
"gen_ai.usage.prompt_tokens",
"metadata.user_api_key_hash",
"metadata.requester_ip_address",
"metadata.user_api_key_team_alias",
"metadata.requester_metadata",
"metadata.user_api_key_team_id",
"metadata.spend_logs_metadata",
"metadata.user_api_key_alias",
"metadata.user_api_key_user_id",
"metadata.user_api_key_org_id",
]
_all_attributes = set([name for name in span.attributes.keys()])

View file

@ -13,10 +13,16 @@ from pydantic.main import Model
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
from datetime import datetime as dt_object
import time
import pytest
import litellm
from litellm.types.utils import Usage
from litellm.types.utils import (
Usage,
StandardLoggingMetadata,
StandardLoggingModelInformation,
StandardLoggingHiddenParams,
)
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
@ -104,3 +110,212 @@ def test_get_additional_headers():
"x_ratelimit_limit_tokens": 160000,
"x_ratelimit_remaining_tokens": 160000,
}
def all_fields_present(standard_logging_metadata: StandardLoggingMetadata):
for field in StandardLoggingMetadata.__annotations__.keys():
assert field in standard_logging_metadata
@pytest.mark.parametrize(
"metadata_key, metadata_value",
[
("user_api_key_alias", "test_alias"),
("user_api_key_hash", "test_hash"),
("user_api_key_team_id", "test_team_id"),
("user_api_key_user_id", "test_user_id"),
("user_api_key_team_alias", "test_team_alias"),
("spend_logs_metadata", {"key": "value"}),
("requester_ip_address", "127.0.0.1"),
("requester_metadata", {"user_agent": "test_agent"}),
],
)
def test_get_standard_logging_metadata(metadata_key, metadata_value):
"""
Test that the get_standard_logging_metadata function correctly sets the metadata fields.
All fields in StandardLoggingMetadata should ALWAYS be present.
"""
metadata = {metadata_key: metadata_value}
standard_logging_metadata = (
StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
)
print("standard_logging_metadata", standard_logging_metadata)
# Assert that all fields in StandardLoggingMetadata are present
all_fields_present(standard_logging_metadata)
# Assert that the specific metadata field is set correctly
assert standard_logging_metadata[metadata_key] == metadata_value
def test_get_standard_logging_metadata_user_api_key_hash():
valid_hash = "a" * 64 # 64 character string
metadata = {"user_api_key": valid_hash}
result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
assert result["user_api_key_hash"] == valid_hash
def test_get_standard_logging_metadata_invalid_user_api_key():
invalid_hash = "not_a_valid_hash"
metadata = {"user_api_key": invalid_hash}
result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
all_fields_present(result)
assert result["user_api_key_hash"] is None
def test_get_standard_logging_metadata_invalid_keys():
metadata = {
"user_api_key_alias": "test_alias",
"invalid_key": "should_be_ignored",
"another_invalid_key": 123,
}
result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
all_fields_present(result)
assert result["user_api_key_alias"] == "test_alias"
assert "invalid_key" not in result
assert "another_invalid_key" not in result
def test_cleanup_timestamps():
"""Test cleanup_timestamps with different input types"""
# Test with datetime objects
now = dt_object.now()
start = now
end = now
completion = now
result = StandardLoggingPayloadSetup.cleanup_timestamps(start, end, completion)
assert all(isinstance(x, float) for x in result)
assert len(result) == 3
# Test with float timestamps
start_float = time.time()
end_float = start_float + 1
completion_float = end_float
result = StandardLoggingPayloadSetup.cleanup_timestamps(
start_float, end_float, completion_float
)
assert all(isinstance(x, float) for x in result)
assert result[0] == start_float
assert result[1] == end_float
assert result[2] == completion_float
# Test with mixed types
result = StandardLoggingPayloadSetup.cleanup_timestamps(
start_float, end, completion_float
)
assert all(isinstance(x, float) for x in result)
# Test invalid input
with pytest.raises(ValueError):
StandardLoggingPayloadSetup.cleanup_timestamps(
"invalid", end_float, completion_float
)
def test_get_model_cost_information():
"""Test get_model_cost_information with different inputs"""
# Test with None values
result = StandardLoggingPayloadSetup.get_model_cost_information(
base_model=None,
custom_pricing=None,
custom_llm_provider=None,
init_response_obj={},
)
assert result["model_map_key"] == ""
assert result["model_map_value"] is None # this was not found in model cost map
# assert all fields in StandardLoggingModelInformation are present
assert all(
field in result for field in StandardLoggingModelInformation.__annotations__
)
# Test with valid model
result = StandardLoggingPayloadSetup.get_model_cost_information(
base_model="gpt-3.5-turbo",
custom_pricing=False,
custom_llm_provider="openai",
init_response_obj={},
)
litellm_info_gpt_3_5_turbo_model_map_value = litellm.get_model_info(
model="gpt-3.5-turbo", custom_llm_provider="openai"
)
print("result", result)
assert result["model_map_key"] == "gpt-3.5-turbo"
assert result["model_map_value"] is not None
assert result["model_map_value"] == litellm_info_gpt_3_5_turbo_model_map_value
# assert all fields in StandardLoggingModelInformation are present
assert all(
field in result for field in StandardLoggingModelInformation.__annotations__
)
def test_get_hidden_params():
"""Test get_hidden_params with different inputs"""
# Test with None
result = StandardLoggingPayloadSetup.get_hidden_params(None)
assert result["model_id"] is None
assert result["cache_key"] is None
assert result["api_base"] is None
assert result["response_cost"] is None
assert result["additional_headers"] is None
# assert all fields in StandardLoggingHiddenParams are present
assert all(field in result for field in StandardLoggingHiddenParams.__annotations__)
# Test with valid params
hidden_params = {
"model_id": "test-model",
"cache_key": "test-cache",
"api_base": "https://api.test.com",
"response_cost": 0.001,
"additional_headers": {
"x-ratelimit-limit-requests": "2000",
"x-ratelimit-remaining-requests": "1999",
},
}
result = StandardLoggingPayloadSetup.get_hidden_params(hidden_params)
assert result["model_id"] == "test-model"
assert result["cache_key"] == "test-cache"
assert result["api_base"] == "https://api.test.com"
assert result["response_cost"] == 0.001
assert result["additional_headers"] is not None
assert result["additional_headers"]["x_ratelimit_limit_requests"] == 2000
# assert all fields in StandardLoggingHiddenParams are present
assert all(field in result for field in StandardLoggingHiddenParams.__annotations__)
def test_get_final_response_obj():
"""Test get_final_response_obj with different input types and redaction scenarios"""
# Test with direct response_obj
response_obj = {"choices": [{"message": {"content": "test content"}}]}
result = StandardLoggingPayloadSetup.get_final_response_obj(
response_obj=response_obj, init_response_obj=None, kwargs={}
)
assert result == response_obj
# Test redaction when litellm.turn_off_message_logging is True
litellm.turn_off_message_logging = True
try:
model_response = litellm.ModelResponse(
choices=[
litellm.Choices(message=litellm.Message(content="sensitive content"))
]
)
kwargs = {"messages": [{"role": "user", "content": "original message"}]}
result = StandardLoggingPayloadSetup.get_final_response_obj(
response_obj=model_response, init_response_obj=model_response, kwargs=kwargs
)
print("result", result)
print("type(result)", type(result))
# Verify response message content was redacted
assert result["choices"][0]["message"]["content"] == "redacted-by-litellm"
# Verify that redaction occurred in kwargs
assert kwargs["messages"][0]["content"] == "redacted-by-litellm"
finally:
# Reset litellm.turn_off_message_logging to its original value
litellm.turn_off_message_logging = False