diff --git a/.circleci/config.yml b/.circleci/config.yml
index 8e63cfe25..7a742afe0 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -722,6 +722,7 @@ jobs:
       - run: python ./tests/documentation_tests/test_general_setting_keys.py
       - run: python ./tests/code_coverage_tests/router_code_coverage.py
       - run: python ./tests/code_coverage_tests/test_router_strategy_async.py
+      - run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
       - run: python ./tests/documentation_tests/test_env_keys.py
       - run: helm lint ./deploy/charts/litellm-helm
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 4753779c0..2ab905e85 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -2474,6 +2474,14 @@ class StandardLoggingPayloadSetup:
     ) -> Tuple[float, float, float]:
         """
         Convert datetime objects to floats
+
+        Args:
+            start_time: Union[dt_object, float]
+            end_time: Union[dt_object, float]
+            completion_start_time: Union[dt_object, float]
+
+        Returns:
+            Tuple[float, float, float]: A tuple containing the start time, end time, and completion start time as floats.
         """
 
         if isinstance(start_time, datetime.datetime):
@@ -2534,13 +2542,10 @@ class StandardLoggingPayloadSetup:
         )
         if isinstance(metadata, dict):
             # Filter the metadata dictionary to include only the specified keys
-            clean_metadata = StandardLoggingMetadata(
-                **{  # type: ignore
-                    key: metadata[key]
-                    for key in StandardLoggingMetadata.__annotations__.keys()
-                    if key in metadata
-                }
-            )
+            supported_keys = StandardLoggingMetadata.__annotations__.keys()
+            for key in supported_keys:
+                if key in metadata:
+                    clean_metadata[key] = metadata[key]  # type: ignore
 
             if metadata.get("user_api_key") is not None:
                 if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
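A note on the second hunk: the rewrite assumes `clean_metadata` has already been initialized earlier in the function with every `StandardLoggingMetadata` field defaulted to `None` (that initialization is outside this hunk), so copying in only the supported keys yields a payload in which every annotated field is always present. A minimal sketch of the pattern, using a hypothetical `ExampleMetadata` stand-in rather than the real type:

```python
from typing import Optional, TypedDict


class ExampleMetadata(TypedDict):
    # Stand-in for StandardLoggingMetadata: every field is declared,
    # so __annotations__ lists the full set of supported keys.
    user_api_key_alias: Optional[str]
    user_api_key_user_id: Optional[str]
    requester_ip_address: Optional[str]


def filter_metadata(raw: dict) -> ExampleMetadata:
    # Start from fully-populated defaults so every annotated field is
    # present in the result, even when the caller omits it.
    clean: ExampleMetadata = {
        "user_api_key_alias": None,
        "user_api_key_user_id": None,
        "requester_ip_address": None,
    }
    for key in ExampleMetadata.__annotations__.keys():
        if key in raw:
            clean[key] = raw[key]  # type: ignore[literal-required]
    return clean


print(filter_metadata({"user_api_key_alias": "test", "junk": 123}))
# {'user_api_key_alias': 'test', 'user_api_key_user_id': None,
#  'requester_ip_address': None}  -- unsupported key 'junk' is dropped
```

The old constructor-based version built the dict from only the keys that happened to be present, so downstream consumers could not rely on a stable set of metadata fields.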
+ """ + called_functions = set() + + for root, _, files in os.walk(base_dir): + for file in files: + if file.endswith(".py") and "logging" in file.lower(): + file_path = os.path.join(root, file) + with open(file_path, "r") as f: + try: + tree = ast.parse(f.read()) + except SyntaxError: + print(f"Warning: Syntax error in file {file_path}") + continue + + for node in ast.walk(tree): + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + called_functions.add(node.func.id) + elif isinstance(node.func, ast.Attribute): + called_functions.add(node.func.attr) + + return called_functions + + +# Functions that can be ignored in test coverage +ignored_function_names = [ + "__init__", + # Add other functions to ignore here +] + + +def main(): + logging_file = "./litellm/litellm_core_utils/litellm_logging.py" + tests_dir = "./tests/" + + # LOCAL TESTING + # logging_file = "../../litellm/litellm_core_utils/litellm_logging.py" + # tests_dir = "../../tests/" + + logging_functions = get_function_names_from_file(logging_file) + print("logging_functions:", logging_functions) + + called_functions_in_tests = get_all_functions_called_in_tests(tests_dir) + untested_functions = [ + fn + for fn in logging_functions + if fn not in called_functions_in_tests and fn not in ignored_function_names + ] + + if untested_functions: + untested_perc = len(untested_functions) / len(logging_functions) + print(f"untested_percentage: {untested_perc * 100:.2f}%") + raise Exception( + f"{untested_perc * 100:.2f}% of functions in litellm_logging.py are not tested: {untested_functions}" + ) + else: + print("All functions in litellm_logging.py are covered by tests.") + + +if __name__ == "__main__": + main() diff --git a/tests/logging_callback_tests/test_otel_logging.py b/tests/logging_callback_tests/test_otel_logging.py index 49212607b..f93cc1ec2 100644 --- a/tests/logging_callback_tests/test_otel_logging.py +++ b/tests/logging_callback_tests/test_otel_logging.py @@ -260,6 +260,15 @@ def validate_redacted_message_span_attributes(span): "llm.usage.total_tokens", "gen_ai.usage.completion_tokens", "gen_ai.usage.prompt_tokens", + "metadata.user_api_key_hash", + "metadata.requester_ip_address", + "metadata.user_api_key_team_alias", + "metadata.requester_metadata", + "metadata.user_api_key_team_id", + "metadata.spend_logs_metadata", + "metadata.user_api_key_alias", + "metadata.user_api_key_user_id", + "metadata.user_api_key_org_id", ] _all_attributes = set([name for name in span.attributes.keys()]) diff --git a/tests/logging_callback_tests/test_standard_logging_payload.py b/tests/logging_callback_tests/test_standard_logging_payload.py index 42d504a1e..654103663 100644 --- a/tests/logging_callback_tests/test_standard_logging_payload.py +++ b/tests/logging_callback_tests/test_standard_logging_payload.py @@ -13,10 +13,16 @@ from pydantic.main import Model sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system-path - +from datetime import datetime as dt_object +import time import pytest import litellm -from litellm.types.utils import Usage +from litellm.types.utils import ( + Usage, + StandardLoggingMetadata, + StandardLoggingModelInformation, + StandardLoggingHiddenParams, +) from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup @@ -104,3 +110,212 @@ def test_get_additional_headers(): "x_ratelimit_limit_tokens": 160000, "x_ratelimit_remaining_tokens": 160000, } + + +def all_fields_present(standard_logging_metadata: StandardLoggingMetadata): + for field in 
diff --git a/tests/logging_callback_tests/test_otel_logging.py b/tests/logging_callback_tests/test_otel_logging.py
index 49212607b..f93cc1ec2 100644
--- a/tests/logging_callback_tests/test_otel_logging.py
+++ b/tests/logging_callback_tests/test_otel_logging.py
@@ -260,6 +260,15 @@ def validate_redacted_message_span_attributes(span):
         "llm.usage.total_tokens",
         "gen_ai.usage.completion_tokens",
         "gen_ai.usage.prompt_tokens",
+        "metadata.user_api_key_hash",
+        "metadata.requester_ip_address",
+        "metadata.user_api_key_team_alias",
+        "metadata.requester_metadata",
+        "metadata.user_api_key_team_id",
+        "metadata.spend_logs_metadata",
+        "metadata.user_api_key_alias",
+        "metadata.user_api_key_user_id",
+        "metadata.user_api_key_org_id",
     ]
 
     _all_attributes = set([name for name in span.attributes.keys()])
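These nine `metadata.*` entries become expected because the standard logging payload now always carries every `StandardLoggingMetadata` field, and the OTel callback flattens each field into a dot-prefixed span attribute. A hypothetical illustration of that mapping (the prefix rule is an assumption inferred from the attribute names above, not the callback's actual code):

```python
# Hypothetical illustration: StandardLoggingMetadata fields flattened into
# dot-prefixed OTel span attribute names.
standard_logging_metadata = {
    "user_api_key_hash": "a" * 64,
    "user_api_key_alias": "prod-key",
    "requester_ip_address": "127.0.0.1",
}

span_attributes = {
    f"metadata.{key}": value for key, value in standard_logging_metadata.items()
}
print(sorted(span_attributes))
# ['metadata.requester_ip_address', 'metadata.user_api_key_alias',
#  'metadata.user_api_key_hash']
```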
diff --git a/tests/logging_callback_tests/test_standard_logging_payload.py b/tests/logging_callback_tests/test_standard_logging_payload.py
index 42d504a1e..654103663 100644
--- a/tests/logging_callback_tests/test_standard_logging_payload.py
+++ b/tests/logging_callback_tests/test_standard_logging_payload.py
@@ -13,10 +13,16 @@ from pydantic.main import Model
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system-path
-
+from datetime import datetime as dt_object
+import time
 import pytest
 import litellm
-from litellm.types.utils import Usage
+from litellm.types.utils import (
+    Usage,
+    StandardLoggingMetadata,
+    StandardLoggingModelInformation,
+    StandardLoggingHiddenParams,
+)
 from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
 
 
@@ -104,3 +110,212 @@ def test_get_additional_headers():
         "x_ratelimit_limit_tokens": 160000,
         "x_ratelimit_remaining_tokens": 160000,
     }
+
+
+def all_fields_present(standard_logging_metadata: StandardLoggingMetadata):
+    for field in StandardLoggingMetadata.__annotations__.keys():
+        assert field in standard_logging_metadata
+
+
+@pytest.mark.parametrize(
+    "metadata_key, metadata_value",
+    [
+        ("user_api_key_alias", "test_alias"),
+        ("user_api_key_hash", "test_hash"),
+        ("user_api_key_team_id", "test_team_id"),
+        ("user_api_key_user_id", "test_user_id"),
+        ("user_api_key_team_alias", "test_team_alias"),
+        ("spend_logs_metadata", {"key": "value"}),
+        ("requester_ip_address", "127.0.0.1"),
+        ("requester_metadata", {"user_agent": "test_agent"}),
+    ],
+)
+def test_get_standard_logging_metadata(metadata_key, metadata_value):
+    """
+    Test that the get_standard_logging_metadata function correctly sets the metadata fields.
+    All fields in StandardLoggingMetadata should ALWAYS be present.
+    """
+    metadata = {metadata_key: metadata_value}
+    standard_logging_metadata = (
+        StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
+    )
+
+    print("standard_logging_metadata", standard_logging_metadata)
+
+    # Assert that all fields in StandardLoggingMetadata are present
+    all_fields_present(standard_logging_metadata)
+
+    # Assert that the specific metadata field is set correctly
+    assert standard_logging_metadata[metadata_key] == metadata_value
+
+
+def test_get_standard_logging_metadata_user_api_key_hash():
+    valid_hash = "a" * 64  # 64 character string
+    metadata = {"user_api_key": valid_hash}
+    result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
+    assert result["user_api_key_hash"] == valid_hash
+
+
+def test_get_standard_logging_metadata_invalid_user_api_key():
+    invalid_hash = "not_a_valid_hash"
+    metadata = {"user_api_key": invalid_hash}
+    result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
+    all_fields_present(result)
+    assert result["user_api_key_hash"] is None
+
+
+def test_get_standard_logging_metadata_invalid_keys():
+    metadata = {
+        "user_api_key_alias": "test_alias",
+        "invalid_key": "should_be_ignored",
+        "another_invalid_key": 123,
+    }
+    result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
+    all_fields_present(result)
+    assert result["user_api_key_alias"] == "test_alias"
+    assert "invalid_key" not in result
+    assert "another_invalid_key" not in result
+
+
+def test_cleanup_timestamps():
+    """Test cleanup_timestamps with different input types"""
+    # Test with datetime objects
+    now = dt_object.now()
+    start = now
+    end = now
+    completion = now
+
+    result = StandardLoggingPayloadSetup.cleanup_timestamps(start, end, completion)
+
+    assert all(isinstance(x, float) for x in result)
+    assert len(result) == 3
+
+    # Test with float timestamps
+    start_float = time.time()
+    end_float = start_float + 1
+    completion_float = end_float
+
+    result = StandardLoggingPayloadSetup.cleanup_timestamps(
+        start_float, end_float, completion_float
+    )
+
+    assert all(isinstance(x, float) for x in result)
+    assert result[0] == start_float
+    assert result[1] == end_float
+    assert result[2] == completion_float
+
+    # Test with mixed types
+    result = StandardLoggingPayloadSetup.cleanup_timestamps(
+        start_float, end, completion_float
+    )
+    assert all(isinstance(x, float) for x in result)
+
+    # Test invalid input
+    with pytest.raises(ValueError):
+        StandardLoggingPayloadSetup.cleanup_timestamps(
+            "invalid", end_float, completion_float
+        )
+
+
+def test_get_model_cost_information():
+    """Test get_model_cost_information with different inputs"""
+    # Test with None values
+    result = StandardLoggingPayloadSetup.get_model_cost_information(
+        base_model=None,
+        custom_pricing=None,
+        custom_llm_provider=None,
+        init_response_obj={},
+    )
+    assert result["model_map_key"] == ""
+    assert result["model_map_value"] is None  # this was not found in model cost map
+    # assert all fields in StandardLoggingModelInformation are present
+    assert all(
+        field in result for field in StandardLoggingModelInformation.__annotations__
+    )
+
+    # Test with valid model
+    result = StandardLoggingPayloadSetup.get_model_cost_information(
+        base_model="gpt-3.5-turbo",
+        custom_pricing=False,
+        custom_llm_provider="openai",
+        init_response_obj={},
+    )
+    litellm_info_gpt_3_5_turbo_model_map_value = litellm.get_model_info(
+        model="gpt-3.5-turbo", custom_llm_provider="openai"
+    )
+    print("result", result)
+    assert result["model_map_key"] == "gpt-3.5-turbo"
+    assert result["model_map_value"] is not None
+    assert result["model_map_value"] == litellm_info_gpt_3_5_turbo_model_map_value
+    # assert all fields in StandardLoggingModelInformation are present
+    assert all(
+        field in result for field in StandardLoggingModelInformation.__annotations__
+    )
+
+
+def test_get_hidden_params():
+    """Test get_hidden_params with different inputs"""
+    # Test with None
+    result = StandardLoggingPayloadSetup.get_hidden_params(None)
+    assert result["model_id"] is None
+    assert result["cache_key"] is None
+    assert result["api_base"] is None
+    assert result["response_cost"] is None
+    assert result["additional_headers"] is None
+
+    # assert all fields in StandardLoggingHiddenParams are present
+    assert all(field in result for field in StandardLoggingHiddenParams.__annotations__)
+
+    # Test with valid params
+    hidden_params = {
+        "model_id": "test-model",
+        "cache_key": "test-cache",
+        "api_base": "https://api.test.com",
+        "response_cost": 0.001,
+        "additional_headers": {
+            "x-ratelimit-limit-requests": "2000",
+            "x-ratelimit-remaining-requests": "1999",
+        },
+    }
+    result = StandardLoggingPayloadSetup.get_hidden_params(hidden_params)
+    assert result["model_id"] == "test-model"
+    assert result["cache_key"] == "test-cache"
+    assert result["api_base"] == "https://api.test.com"
+    assert result["response_cost"] == 0.001
+    assert result["additional_headers"] is not None
+    assert result["additional_headers"]["x_ratelimit_limit_requests"] == 2000
+    # assert all fields in StandardLoggingHiddenParams are present
+    assert all(field in result for field in StandardLoggingHiddenParams.__annotations__)
+
+
+def test_get_final_response_obj():
+    """Test get_final_response_obj with different input types and redaction scenarios"""
+    # Test with direct response_obj
+    response_obj = {"choices": [{"message": {"content": "test content"}}]}
+    result = StandardLoggingPayloadSetup.get_final_response_obj(
+        response_obj=response_obj, init_response_obj=None, kwargs={}
+    )
+    assert result == response_obj
+
+    # Test redaction when litellm.turn_off_message_logging is True
+    litellm.turn_off_message_logging = True
+    try:
+        model_response = litellm.ModelResponse(
+            choices=[
+                litellm.Choices(message=litellm.Message(content="sensitive content"))
+            ]
+        )
+        kwargs = {"messages": [{"role": "user", "content": "original message"}]}
+        result = StandardLoggingPayloadSetup.get_final_response_obj(
+            response_obj=model_response, init_response_obj=model_response, kwargs=kwargs
+        )
+
+        print("result", result)
+        print("type(result)", type(result))
+        # Verify response message content was redacted
+        assert result["choices"][0]["message"]["content"] == "redacted-by-litellm"
+        # Verify that redaction occurred in kwargs
+        assert kwargs["messages"][0]["content"] == "redacted-by-litellm"
+    finally:
+        # Reset litellm.turn_off_message_logging to its original value
+        litellm.turn_off_message_logging = False
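To tie the pieces together, a hedged sketch of exercising the helper these tests pin down, outside pytest. It assumes a litellm checkout on PYTHONPATH; the expected values follow from the assertions above:

```python
# A sketch, not a recipe: calling the helper the new tests cover directly.
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup

meta = StandardLoggingPayloadSetup.get_standard_logging_metadata(
    {"user_api_key_alias": "demo", "unsupported_key": "dropped"}
)
print(meta["user_api_key_alias"])  # "demo"
print("unsupported_key" in meta)   # False: unsupported keys are filtered out
print(meta["user_api_key_hash"])   # None: unset fields are still present
```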