From 8359cb6fa9bf7b0bf4f3df630cf8666adffa2813 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 22 Oct 2024 19:13:19 +0530 Subject: [PATCH] (fix) standard logging metadata + add unit testing (#6366) * fix setting StandardLoggingMetadata * add unit testing for standard logging metadata * fix otel logging test * fix linting * fix typing --- litellm/litellm_core_utils/litellm_logging.py | 12 +-- .../test_otel_logging.py | 8 ++ .../test_standard_logging_payload.py | 86 +++++++++++++++++++ 3 files changed, 98 insertions(+), 8 deletions(-) create mode 100644 tests/logging_callback_tests/test_standard_logging_payload.py diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index f41ac256b..f1803043c 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -2801,14 +2801,10 @@ def get_standard_logging_metadata( ) if isinstance(metadata, dict): # Filter the metadata dictionary to include only the specified keys - clean_metadata = StandardLoggingMetadata( - **{ # type: ignore - key: metadata[key] - for key in StandardLoggingMetadata.__annotations__.keys() - if key in metadata - } - ) - + supported_keys = StandardLoggingMetadata.__annotations__.keys() + for key in supported_keys: + if key in metadata: + clean_metadata[key] = metadata[key] # type: ignore if metadata.get("user_api_key") is not None: if is_valid_sha256_hash(str(metadata.get("user_api_key"))): clean_metadata["user_api_key_hash"] = metadata.get( diff --git a/tests/logging_callback_tests/test_otel_logging.py b/tests/logging_callback_tests/test_otel_logging.py index 49212607b..14cfa1c13 100644 --- a/tests/logging_callback_tests/test_otel_logging.py +++ b/tests/logging_callback_tests/test_otel_logging.py @@ -260,6 +260,14 @@ def validate_redacted_message_span_attributes(span): "llm.usage.total_tokens", "gen_ai.usage.completion_tokens", "gen_ai.usage.prompt_tokens", + "metadata.user_api_key_hash", + "metadata.requester_ip_address", + "metadata.user_api_key_team_alias", + "metadata.requester_metadata", + "metadata.user_api_key_team_id", + "metadata.spend_logs_metadata", + "metadata.user_api_key_alias", + "metadata.user_api_key_user_id", ] _all_attributes = set([name for name in span.attributes.keys()]) diff --git a/tests/logging_callback_tests/test_standard_logging_payload.py b/tests/logging_callback_tests/test_standard_logging_payload.py new file mode 100644 index 000000000..7ae3ae6ed --- /dev/null +++ b/tests/logging_callback_tests/test_standard_logging_payload.py @@ -0,0 +1,86 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock + +from pydantic.main import Model + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system-path + +from typing import Literal + +import pytest +import litellm +import asyncio +import logging +from litellm.litellm_core_utils.litellm_logging import ( + get_standard_logging_metadata, + StandardLoggingMetadata, +) + + +def all_fields_present(standard_logging_metadata: StandardLoggingMetadata): + for field in StandardLoggingMetadata.__annotations__.keys(): + assert field in standard_logging_metadata + + +@pytest.mark.parametrize( + "metadata_key, metadata_value", + [ + ("user_api_key_alias", "test_alias"), + ("user_api_key_hash", "test_hash"), + ("user_api_key_team_id", "test_team_id"), + ("user_api_key_user_id", "test_user_id"), + ("user_api_key_team_alias", "test_team_alias"), + ("spend_logs_metadata", {"key": "value"}), + ("requester_ip_address", "127.0.0.1"), + ("requester_metadata", {"user_agent": "test_agent"}), + ], +) +def test_get_standard_logging_metadata(metadata_key, metadata_value): + """ + Test that the get_standard_logging_metadata function correctly sets the metadata fields. + + All fields in StandardLoggingMetadata should ALWAYS be present. + """ + metadata = {metadata_key: metadata_value} + standard_logging_metadata = get_standard_logging_metadata(metadata) + + print("standard_logging_metadata", standard_logging_metadata) + + # Assert that all fields in StandardLoggingMetadata are present + all_fields_present(standard_logging_metadata) + + # Assert that the specific metadata field is set correctly + assert standard_logging_metadata[metadata_key] == metadata_value + + +def test_get_standard_logging_metadata_user_api_key_hash(): + valid_hash = "a" * 64 # 64 character string + metadata = {"user_api_key": valid_hash} + result = get_standard_logging_metadata(metadata) + assert result["user_api_key_hash"] == valid_hash + + +def test_get_standard_logging_metadata_invalid_user_api_key(): + invalid_hash = "not_a_valid_hash" + metadata = {"user_api_key": invalid_hash} + result = get_standard_logging_metadata(metadata) + all_fields_present(result) + assert result["user_api_key_hash"] is None + + +def test_get_standard_logging_metadata_invalid_keys(): + metadata = { + "user_api_key_alias": "test_alias", + "invalid_key": "should_be_ignored", + "another_invalid_key": 123, + } + result = get_standard_logging_metadata(metadata) + all_fields_present(result) + assert result["user_api_key_alias"] == "test_alias" + assert "invalid_key" not in result + assert "another_invalid_key" not in result