diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 5a0a9c55ef..b763a4f636 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -48,6 +48,7 @@ from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
 from litellm.types.utils import (
     CallTypes,
     EmbeddingResponse,
+    FieldsWithMessageContent,
     ImageResponse,
     LiteLLMLoggingBaseClass,
     ModelResponse,
@@ -3160,6 +3161,32 @@ class StandardLoggingPayloadSetup:
         else:
             return end_time_float - start_time_float

+    @staticmethod
+    def _remove_message_content_from_dict(original_dict: Optional[dict]) -> dict:
+        """
+        Filters out any params that carry message content: `messages`, `input`, `prompt`, `query`.
+
+        e.g. we don't want to log the prompt in the model parameters.
+        """
+        if original_dict is None:
+            return {}
+        sensitive_keys = FieldsWithMessageContent.get_all_fields()
+        cleaned_optional_params = {}
+        for key in original_dict:
+            if key not in sensitive_keys:
+                cleaned_optional_params[key] = original_dict[key]
+        return cleaned_optional_params
+
+    @staticmethod
+    def _get_model_parameters(kwargs: dict) -> dict:
+        """
+        Get the model parameters from the kwargs.
+        """
+        optional_params = kwargs.get("optional_params", {}) or {}
+        return StandardLoggingPayloadSetup._remove_message_content_from_dict(
+            optional_params
+        )
+

 def get_standard_logging_object_payload(
     kwargs: Optional[dict],
@@ -3330,7 +3357,7 @@ def get_standard_logging_object_payload(
         requester_ip_address=clean_metadata.get("requester_ip_address", None),
         messages=kwargs.get("messages"),
         response=final_response_obj,
-        model_parameters=kwargs.get("optional_params", None),
+        model_parameters=StandardLoggingPayloadSetup._get_model_parameters(kwargs),
         hidden_params=clean_hidden_params,
         model_map_information=model_cost_information,
         error_str=error_str,
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index dcaf5f35d1..0c6014cdba 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1963,3 +1963,14 @@ class ProviderSpecificHeader(TypedDict):
 class SelectTokenizerResponse(TypedDict):
     type: Literal["openai_tokenizer", "huggingface_tokenizer"]
     tokenizer: Any
+
+
+class FieldsWithMessageContent(str, Enum):
+    MESSAGES = "messages"
+    INPUT = "input"
+    PROMPT = "prompt"
+    QUERY = "query"
+
+    @classmethod
+    def get_all_fields(cls) -> List[str]:
+        return [field.value for field in cls]
diff --git a/tests/litellm/litellm_core_utils/test_litellm_logging.py b/tests/litellm/litellm_core_utils/test_litellm_logging.py
new file mode 100644
index 0000000000..2d888d5982
--- /dev/null
+++ b/tests/litellm/litellm_core_utils/test_litellm_logging.py
@@ -0,0 +1,104 @@
+import json
+import os
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
+
+
+def test_remove_message_content_from_dict():
+    # Test with None input
+    assert StandardLoggingPayloadSetup._remove_message_content_from_dict(None) == {}
+
+    # Test with empty dict
+    assert StandardLoggingPayloadSetup._remove_message_content_from_dict({}) == {}
+
+    # Test with sensitive content
+    input_dict = {
+        "messages": "sensitive content",
+        "input": "secret prompt",
+        "prompt": "confidential text",
+        "safe_key": "safe value",
+        "temperature": 0.7,
+    }
+
+    expected_output = {"safe_key": "safe value", "temperature": 0.7}
+
+    result = StandardLoggingPayloadSetup._remove_message_content_from_dict(input_dict)
+    assert result == expected_output
+
+
+def test_get_model_parameters():
+    # Test with empty kwargs
+    assert StandardLoggingPayloadSetup._get_model_parameters({}) == {}
+
+    # Test with None optional_params
+    assert (
+        StandardLoggingPayloadSetup._get_model_parameters({"optional_params": None})
+        == {}
+    )
+
+    # Test with actual parameters
+    kwargs = {
+        "optional_params": {
+            "temperature": 0.8,
+            "messages": "sensitive data",
+            "max_tokens": 100,
+            "prompt": "secret prompt",
+        }
+    }
+
+    expected_output = {"temperature": 0.8, "max_tokens": 100}
+
+    result = StandardLoggingPayloadSetup._get_model_parameters(kwargs)
+    assert result == expected_output
+
+
+def test_get_model_parameters_complex():
+    # Test with more complex optional parameters
+    kwargs = {
+        "optional_params": {
+            "temperature": 0.8,
+            "messages": [{"role": "user", "content": "sensitive"}],
+            "max_tokens": 100,
+            "stop": ["\n", "stop"],
+            "presence_penalty": 0.5,
+            "frequency_penalty": 0.3,
+            "prompt": "secret prompt",
+            "n": 1,
+            "best_of": 2,
+            "logit_bias": {"50256": -100},
+        }
+    }
+
+    expected_output = {
+        "temperature": 0.8,
+        "max_tokens": 100,
+        "stop": ["\n", "stop"],
+        "presence_penalty": 0.5,
+        "frequency_penalty": 0.3,
+        "n": 1,
+        "best_of": 2,
+        "logit_bias": {"50256": -100},
+    }
+
+    result = StandardLoggingPayloadSetup._get_model_parameters(kwargs)
+    assert result == expected_output
+
+
+def test_get_model_parameters_empty_optional_params():
+    # Test with empty optional_params dictionary
+    kwargs = {"optional_params": {}}
+    assert StandardLoggingPayloadSetup._get_model_parameters(kwargs) == {}
+
+
+def test_get_model_parameters_missing_optional_params():
+    # Test with missing optional_params key
+    kwargs = {"model": "gpt-4", "api_key": "test"}
+    assert StandardLoggingPayloadSetup._get_model_parameters(kwargs) == {}
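
Reviewer note: below is a minimal, self-contained sketch of the filtering behavior this diff introduces. `FieldsWithMessageContent` and the filter are re-declared locally (rather than imported from litellm) so the snippet runs on its own, and the dict-comprehension body is a compact equivalent of the patched loop, not the patch's literal code.

    # sketch.py -- standalone illustration of the patched behavior
    from enum import Enum
    from typing import List, Optional


    class FieldsWithMessageContent(str, Enum):
        MESSAGES = "messages"
        INPUT = "input"
        PROMPT = "prompt"
        QUERY = "query"

        @classmethod
        def get_all_fields(cls) -> List[str]:
            return [field.value for field in cls]


    def remove_message_content_from_dict(original_dict: Optional[dict]) -> dict:
        # Drop any key that can carry raw message content before it is logged.
        if original_dict is None:
            return {}
        sensitive_keys = FieldsWithMessageContent.get_all_fields()
        return {k: v for k, v in original_dict.items() if k not in sensitive_keys}


    optional_params = {
        "temperature": 0.7,
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "do not log me"}],
    }
    print(remove_message_content_from_dict(optional_params))
    # -> {'temperature': 0.7, 'max_tokens': 256}

The net effect on the logging payload: `model_parameters` previously received `optional_params` verbatim, prompts included; it now receives only the non-content tuning parameters.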