forked from phoenix/litellm-mirror
(QOL improvement) add unit testing for all static_methods in litellm_logging.py (#6640)
* add unit testing for standard logging payload * unit testing for static methods in litellm_logging * add code coverage check for litellm_logging * litellm_logging_code_coverage * test_get_final_response_obj * fix validate_redacted_message_span_attributes * test validate_redacted_message_span_attributes
This commit is contained in:
parent
6e4a9bb3b7
commit
ae385cfcdc
5 changed files with 334 additions and 9 deletions
|
@ -722,6 +722,7 @@ jobs:
|
|||
- run: python ./tests/documentation_tests/test_general_setting_keys.py
|
||||
- run: python ./tests/code_coverage_tests/router_code_coverage.py
|
||||
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
|
||||
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
|
||||
- run: python ./tests/documentation_tests/test_env_keys.py
|
||||
- run: helm lint ./deploy/charts/litellm-helm
|
||||
|
||||
|
|
|
@ -2474,6 +2474,14 @@ class StandardLoggingPayloadSetup:
|
|||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Convert datetime objects to floats
|
||||
|
||||
Args:
|
||||
start_time: Union[dt_object, float]
|
||||
end_time: Union[dt_object, float]
|
||||
completion_start_time: Union[dt_object, float]
|
||||
|
||||
Returns:
|
||||
Tuple[float, float, float]: A tuple containing the start time, end time, and completion start time as floats.
|
||||
"""
|
||||
|
||||
if isinstance(start_time, datetime.datetime):
|
||||
|
@ -2534,13 +2542,10 @@ class StandardLoggingPayloadSetup:
|
|||
)
|
||||
if isinstance(metadata, dict):
|
||||
# Filter the metadata dictionary to include only the specified keys
|
||||
clean_metadata = StandardLoggingMetadata(
|
||||
**{ # type: ignore
|
||||
key: metadata[key]
|
||||
for key in StandardLoggingMetadata.__annotations__.keys()
|
||||
if key in metadata
|
||||
}
|
||||
)
|
||||
supported_keys = StandardLoggingMetadata.__annotations__.keys()
|
||||
for key in supported_keys:
|
||||
if key in metadata:
|
||||
clean_metadata[key] = metadata[key] # type: ignore
|
||||
|
||||
if metadata.get("user_api_key") is not None:
|
||||
if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
|
||||
|
|
95
tests/code_coverage_tests/litellm_logging_code_coverage.py
Normal file
95
tests/code_coverage_tests/litellm_logging_code_coverage.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
import ast
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
|
||||
def get_function_names_from_file(file_path: str) -> List[str]:
|
||||
"""
|
||||
Extracts all static method names from litellm_logging.py
|
||||
"""
|
||||
with open(file_path, "r") as file:
|
||||
tree = ast.parse(file.read())
|
||||
|
||||
function_names = []
|
||||
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.ClassDef):
|
||||
# Functions inside classes
|
||||
for class_node in node.body:
|
||||
if isinstance(class_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
# Check if the function has @staticmethod decorator
|
||||
for decorator in class_node.decorator_list:
|
||||
if (
|
||||
isinstance(decorator, ast.Name)
|
||||
and decorator.id == "staticmethod"
|
||||
):
|
||||
function_names.append(class_node.name)
|
||||
|
||||
return function_names
|
||||
|
||||
|
||||
def get_all_functions_called_in_tests(base_dir: str) -> set:
|
||||
"""
|
||||
Returns a set of function names that are called in test functions
|
||||
inside test files containing the word 'logging'.
|
||||
"""
|
||||
called_functions = set()
|
||||
|
||||
for root, _, files in os.walk(base_dir):
|
||||
for file in files:
|
||||
if file.endswith(".py") and "logging" in file.lower():
|
||||
file_path = os.path.join(root, file)
|
||||
with open(file_path, "r") as f:
|
||||
try:
|
||||
tree = ast.parse(f.read())
|
||||
except SyntaxError:
|
||||
print(f"Warning: Syntax error in file {file_path}")
|
||||
continue
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Call):
|
||||
if isinstance(node.func, ast.Name):
|
||||
called_functions.add(node.func.id)
|
||||
elif isinstance(node.func, ast.Attribute):
|
||||
called_functions.add(node.func.attr)
|
||||
|
||||
return called_functions
|
||||
|
||||
|
||||
# Functions that can be ignored in test coverage
|
||||
ignored_function_names = [
|
||||
"__init__",
|
||||
# Add other functions to ignore here
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
logging_file = "./litellm/litellm_core_utils/litellm_logging.py"
|
||||
tests_dir = "./tests/"
|
||||
|
||||
# LOCAL TESTING
|
||||
# logging_file = "../../litellm/litellm_core_utils/litellm_logging.py"
|
||||
# tests_dir = "../../tests/"
|
||||
|
||||
logging_functions = get_function_names_from_file(logging_file)
|
||||
print("logging_functions:", logging_functions)
|
||||
|
||||
called_functions_in_tests = get_all_functions_called_in_tests(tests_dir)
|
||||
untested_functions = [
|
||||
fn
|
||||
for fn in logging_functions
|
||||
if fn not in called_functions_in_tests and fn not in ignored_function_names
|
||||
]
|
||||
|
||||
if untested_functions:
|
||||
untested_perc = len(untested_functions) / len(logging_functions)
|
||||
print(f"untested_percentage: {untested_perc * 100:.2f}%")
|
||||
raise Exception(
|
||||
f"{untested_perc * 100:.2f}% of functions in litellm_logging.py are not tested: {untested_functions}"
|
||||
)
|
||||
else:
|
||||
print("All functions in litellm_logging.py are covered by tests.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -260,6 +260,15 @@ def validate_redacted_message_span_attributes(span):
|
|||
"llm.usage.total_tokens",
|
||||
"gen_ai.usage.completion_tokens",
|
||||
"gen_ai.usage.prompt_tokens",
|
||||
"metadata.user_api_key_hash",
|
||||
"metadata.requester_ip_address",
|
||||
"metadata.user_api_key_team_alias",
|
||||
"metadata.requester_metadata",
|
||||
"metadata.user_api_key_team_id",
|
||||
"metadata.spend_logs_metadata",
|
||||
"metadata.user_api_key_alias",
|
||||
"metadata.user_api_key_user_id",
|
||||
"metadata.user_api_key_org_id",
|
||||
]
|
||||
|
||||
_all_attributes = set([name for name in span.attributes.keys()])
|
||||
|
|
|
@ -13,10 +13,16 @@ from pydantic.main import Model
|
|||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system-path
|
||||
|
||||
from datetime import datetime as dt_object
|
||||
import time
|
||||
import pytest
|
||||
import litellm
|
||||
from litellm.types.utils import Usage
|
||||
from litellm.types.utils import (
|
||||
Usage,
|
||||
StandardLoggingMetadata,
|
||||
StandardLoggingModelInformation,
|
||||
StandardLoggingHiddenParams,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
|
||||
|
||||
|
||||
|
@ -104,3 +110,212 @@ def test_get_additional_headers():
|
|||
"x_ratelimit_limit_tokens": 160000,
|
||||
"x_ratelimit_remaining_tokens": 160000,
|
||||
}
|
||||
|
||||
|
||||
def all_fields_present(standard_logging_metadata: StandardLoggingMetadata):
|
||||
for field in StandardLoggingMetadata.__annotations__.keys():
|
||||
assert field in standard_logging_metadata
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metadata_key, metadata_value",
|
||||
[
|
||||
("user_api_key_alias", "test_alias"),
|
||||
("user_api_key_hash", "test_hash"),
|
||||
("user_api_key_team_id", "test_team_id"),
|
||||
("user_api_key_user_id", "test_user_id"),
|
||||
("user_api_key_team_alias", "test_team_alias"),
|
||||
("spend_logs_metadata", {"key": "value"}),
|
||||
("requester_ip_address", "127.0.0.1"),
|
||||
("requester_metadata", {"user_agent": "test_agent"}),
|
||||
],
|
||||
)
|
||||
def test_get_standard_logging_metadata(metadata_key, metadata_value):
|
||||
"""
|
||||
Test that the get_standard_logging_metadata function correctly sets the metadata fields.
|
||||
All fields in StandardLoggingMetadata should ALWAYS be present.
|
||||
"""
|
||||
metadata = {metadata_key: metadata_value}
|
||||
standard_logging_metadata = (
|
||||
StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
|
||||
)
|
||||
|
||||
print("standard_logging_metadata", standard_logging_metadata)
|
||||
|
||||
# Assert that all fields in StandardLoggingMetadata are present
|
||||
all_fields_present(standard_logging_metadata)
|
||||
|
||||
# Assert that the specific metadata field is set correctly
|
||||
assert standard_logging_metadata[metadata_key] == metadata_value
|
||||
|
||||
|
||||
def test_get_standard_logging_metadata_user_api_key_hash():
|
||||
valid_hash = "a" * 64 # 64 character string
|
||||
metadata = {"user_api_key": valid_hash}
|
||||
result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
|
||||
assert result["user_api_key_hash"] == valid_hash
|
||||
|
||||
|
||||
def test_get_standard_logging_metadata_invalid_user_api_key():
|
||||
invalid_hash = "not_a_valid_hash"
|
||||
metadata = {"user_api_key": invalid_hash}
|
||||
result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
|
||||
all_fields_present(result)
|
||||
assert result["user_api_key_hash"] is None
|
||||
|
||||
|
||||
def test_get_standard_logging_metadata_invalid_keys():
|
||||
metadata = {
|
||||
"user_api_key_alias": "test_alias",
|
||||
"invalid_key": "should_be_ignored",
|
||||
"another_invalid_key": 123,
|
||||
}
|
||||
result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
|
||||
all_fields_present(result)
|
||||
assert result["user_api_key_alias"] == "test_alias"
|
||||
assert "invalid_key" not in result
|
||||
assert "another_invalid_key" not in result
|
||||
|
||||
|
||||
def test_cleanup_timestamps():
|
||||
"""Test cleanup_timestamps with different input types"""
|
||||
# Test with datetime objects
|
||||
now = dt_object.now()
|
||||
start = now
|
||||
end = now
|
||||
completion = now
|
||||
|
||||
result = StandardLoggingPayloadSetup.cleanup_timestamps(start, end, completion)
|
||||
|
||||
assert all(isinstance(x, float) for x in result)
|
||||
assert len(result) == 3
|
||||
|
||||
# Test with float timestamps
|
||||
start_float = time.time()
|
||||
end_float = start_float + 1
|
||||
completion_float = end_float
|
||||
|
||||
result = StandardLoggingPayloadSetup.cleanup_timestamps(
|
||||
start_float, end_float, completion_float
|
||||
)
|
||||
|
||||
assert all(isinstance(x, float) for x in result)
|
||||
assert result[0] == start_float
|
||||
assert result[1] == end_float
|
||||
assert result[2] == completion_float
|
||||
|
||||
# Test with mixed types
|
||||
result = StandardLoggingPayloadSetup.cleanup_timestamps(
|
||||
start_float, end, completion_float
|
||||
)
|
||||
assert all(isinstance(x, float) for x in result)
|
||||
|
||||
# Test invalid input
|
||||
with pytest.raises(ValueError):
|
||||
StandardLoggingPayloadSetup.cleanup_timestamps(
|
||||
"invalid", end_float, completion_float
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_cost_information():
|
||||
"""Test get_model_cost_information with different inputs"""
|
||||
# Test with None values
|
||||
result = StandardLoggingPayloadSetup.get_model_cost_information(
|
||||
base_model=None,
|
||||
custom_pricing=None,
|
||||
custom_llm_provider=None,
|
||||
init_response_obj={},
|
||||
)
|
||||
assert result["model_map_key"] == ""
|
||||
assert result["model_map_value"] is None # this was not found in model cost map
|
||||
# assert all fields in StandardLoggingModelInformation are present
|
||||
assert all(
|
||||
field in result for field in StandardLoggingModelInformation.__annotations__
|
||||
)
|
||||
|
||||
# Test with valid model
|
||||
result = StandardLoggingPayloadSetup.get_model_cost_information(
|
||||
base_model="gpt-3.5-turbo",
|
||||
custom_pricing=False,
|
||||
custom_llm_provider="openai",
|
||||
init_response_obj={},
|
||||
)
|
||||
litellm_info_gpt_3_5_turbo_model_map_value = litellm.get_model_info(
|
||||
model="gpt-3.5-turbo", custom_llm_provider="openai"
|
||||
)
|
||||
print("result", result)
|
||||
assert result["model_map_key"] == "gpt-3.5-turbo"
|
||||
assert result["model_map_value"] is not None
|
||||
assert result["model_map_value"] == litellm_info_gpt_3_5_turbo_model_map_value
|
||||
# assert all fields in StandardLoggingModelInformation are present
|
||||
assert all(
|
||||
field in result for field in StandardLoggingModelInformation.__annotations__
|
||||
)
|
||||
|
||||
|
||||
def test_get_hidden_params():
|
||||
"""Test get_hidden_params with different inputs"""
|
||||
# Test with None
|
||||
result = StandardLoggingPayloadSetup.get_hidden_params(None)
|
||||
assert result["model_id"] is None
|
||||
assert result["cache_key"] is None
|
||||
assert result["api_base"] is None
|
||||
assert result["response_cost"] is None
|
||||
assert result["additional_headers"] is None
|
||||
|
||||
# assert all fields in StandardLoggingHiddenParams are present
|
||||
assert all(field in result for field in StandardLoggingHiddenParams.__annotations__)
|
||||
|
||||
# Test with valid params
|
||||
hidden_params = {
|
||||
"model_id": "test-model",
|
||||
"cache_key": "test-cache",
|
||||
"api_base": "https://api.test.com",
|
||||
"response_cost": 0.001,
|
||||
"additional_headers": {
|
||||
"x-ratelimit-limit-requests": "2000",
|
||||
"x-ratelimit-remaining-requests": "1999",
|
||||
},
|
||||
}
|
||||
result = StandardLoggingPayloadSetup.get_hidden_params(hidden_params)
|
||||
assert result["model_id"] == "test-model"
|
||||
assert result["cache_key"] == "test-cache"
|
||||
assert result["api_base"] == "https://api.test.com"
|
||||
assert result["response_cost"] == 0.001
|
||||
assert result["additional_headers"] is not None
|
||||
assert result["additional_headers"]["x_ratelimit_limit_requests"] == 2000
|
||||
# assert all fields in StandardLoggingHiddenParams are present
|
||||
assert all(field in result for field in StandardLoggingHiddenParams.__annotations__)
|
||||
|
||||
|
||||
def test_get_final_response_obj():
|
||||
"""Test get_final_response_obj with different input types and redaction scenarios"""
|
||||
# Test with direct response_obj
|
||||
response_obj = {"choices": [{"message": {"content": "test content"}}]}
|
||||
result = StandardLoggingPayloadSetup.get_final_response_obj(
|
||||
response_obj=response_obj, init_response_obj=None, kwargs={}
|
||||
)
|
||||
assert result == response_obj
|
||||
|
||||
# Test redaction when litellm.turn_off_message_logging is True
|
||||
litellm.turn_off_message_logging = True
|
||||
try:
|
||||
model_response = litellm.ModelResponse(
|
||||
choices=[
|
||||
litellm.Choices(message=litellm.Message(content="sensitive content"))
|
||||
]
|
||||
)
|
||||
kwargs = {"messages": [{"role": "user", "content": "original message"}]}
|
||||
result = StandardLoggingPayloadSetup.get_final_response_obj(
|
||||
response_obj=model_response, init_response_obj=model_response, kwargs=kwargs
|
||||
)
|
||||
|
||||
print("result", result)
|
||||
print("type(result)", type(result))
|
||||
# Verify response message content was redacted
|
||||
assert result["choices"][0]["message"]["content"] == "redacted-by-litellm"
|
||||
# Verify that redaction occurred in kwargs
|
||||
assert kwargs["messages"][0]["content"] == "redacted-by-litellm"
|
||||
finally:
|
||||
# Reset litellm.turn_off_message_logging to its original value
|
||||
litellm.turn_off_message_logging = False
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue