(Bug fix) - don't log messages in model_parameters in StandardLoggingPayload (#8932)

* define model param helper

* use ModelParamHelper

* get_standard_logging_model_parameters

* fix code quality

* get_standard_logging_model_parameters

* StandardLoggingPayload

* test_get_kwargs_for_cache_key

* test_langsmith_key_based_logging

* fix code qa

* fix linting
This commit is contained in:
Ishaan Jaff 2025-03-01 13:39:45 -08:00 committed by GitHub
parent ee7cd60fdb
commit bc9b3e4847
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 151 additions and 90 deletions

View file

@ -5,9 +5,15 @@ If the ddtrace package is not installed, the tracer will be a no-op.
"""
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Union
from litellm.secret_managers.main import get_secret_bool
if TYPE_CHECKING:
from ddtrace.tracer import Tracer as DD_TRACER
else:
DD_TRACER = Any
class NullSpan:
"""A no-op span implementation."""
@ -53,12 +59,13 @@ def _should_use_dd_tracer():
# Initialize tracer
should_use_dd_tracer = _should_use_dd_tracer()
tracer = None
tracer: Union[NullTracer, DD_TRACER] = NullTracer()
# We need to ensure tracer is never None and always has the required methods
if should_use_dd_tracer:
try:
from ddtrace import tracer as dd_tracer
# Define the type to match what's expected by the code using this module
tracer = dd_tracer
except ImportError:
tracer = NullTracer()

View file

@ -33,6 +33,7 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.mlflow import MlflowLogger
from litellm.integrations.pagerduty.pagerduty import PagerDutyAlerting
from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_custom_logger,
redact_message_input_output_from_logging,
@ -3330,7 +3331,9 @@ def get_standard_logging_object_payload(
requester_ip_address=clean_metadata.get("requester_ip_address", None),
messages=kwargs.get("messages"),
response=final_response_obj,
model_parameters=kwargs.get("optional_params", None),
model_parameters=ModelParamHelper.get_standard_logging_model_parameters(
kwargs.get("optional_params", None) or {}
),
hidden_params=clean_hidden_params,
model_map_information=model_cost_information,
error_str=error_str,

View file

@ -0,0 +1,133 @@
from typing import Set
from openai.types.audio.transcription_create_params import TranscriptionCreateParams
from openai.types.chat.completion_create_params import (
CompletionCreateParamsNonStreaming,
CompletionCreateParamsStreaming,
)
from openai.types.completion_create_params import (
CompletionCreateParamsNonStreaming as TextCompletionCreateParamsNonStreaming,
)
from openai.types.completion_create_params import (
CompletionCreateParamsStreaming as TextCompletionCreateParamsStreaming,
)
from openai.types.embedding_create_params import EmbeddingCreateParams
from litellm.types.rerank import RerankRequest
class ModelParamHelper:
@staticmethod
def get_standard_logging_model_parameters(
model_parameters: dict,
) -> dict:
""" """
standard_logging_model_parameters: dict = {}
supported_model_parameters = (
ModelParamHelper._get_relevant_args_to_use_for_logging()
)
for key, value in model_parameters.items():
if key in supported_model_parameters:
standard_logging_model_parameters[key] = value
return standard_logging_model_parameters
@staticmethod
def get_exclude_params_for_model_parameters() -> Set[str]:
return set(["messages", "prompt", "input"])
@staticmethod
def _get_relevant_args_to_use_for_logging() -> Set[str]:
"""
Gets all relevant llm api params besides the ones with prompt content
"""
all_openai_llm_api_params = ModelParamHelper._get_all_llm_api_params()
# Exclude parameters that contain prompt content
combined_kwargs = all_openai_llm_api_params.difference(
set(ModelParamHelper.get_exclude_params_for_model_parameters())
)
return combined_kwargs
@staticmethod
def _get_all_llm_api_params() -> Set[str]:
"""
Gets the supported kwargs for each call type and combines them
"""
chat_completion_kwargs = (
ModelParamHelper._get_litellm_supported_chat_completion_kwargs()
)
text_completion_kwargs = (
ModelParamHelper._get_litellm_supported_text_completion_kwargs()
)
embedding_kwargs = ModelParamHelper._get_litellm_supported_embedding_kwargs()
transcription_kwargs = (
ModelParamHelper._get_litellm_supported_transcription_kwargs()
)
rerank_kwargs = ModelParamHelper._get_litellm_supported_rerank_kwargs()
exclude_kwargs = ModelParamHelper._get_exclude_kwargs()
combined_kwargs = chat_completion_kwargs.union(
text_completion_kwargs,
embedding_kwargs,
transcription_kwargs,
rerank_kwargs,
)
combined_kwargs = combined_kwargs.difference(exclude_kwargs)
return combined_kwargs
@staticmethod
def _get_litellm_supported_chat_completion_kwargs() -> Set[str]:
"""
Get the litellm supported chat completion kwargs
This follows the OpenAI API Spec
"""
all_chat_completion_kwargs = set(
CompletionCreateParamsNonStreaming.__annotations__.keys()
).union(set(CompletionCreateParamsStreaming.__annotations__.keys()))
return all_chat_completion_kwargs
@staticmethod
def _get_litellm_supported_text_completion_kwargs() -> Set[str]:
"""
Get the litellm supported text completion kwargs
This follows the OpenAI API Spec
"""
all_text_completion_kwargs = set(
TextCompletionCreateParamsNonStreaming.__annotations__.keys()
).union(set(TextCompletionCreateParamsStreaming.__annotations__.keys()))
return all_text_completion_kwargs
@staticmethod
def _get_litellm_supported_rerank_kwargs() -> Set[str]:
"""
Get the litellm supported rerank kwargs
"""
return set(RerankRequest.model_fields.keys())
@staticmethod
def _get_litellm_supported_embedding_kwargs() -> Set[str]:
"""
Get the litellm supported embedding kwargs
This follows the OpenAI API Spec
"""
return set(EmbeddingCreateParams.__annotations__.keys())
@staticmethod
def _get_litellm_supported_transcription_kwargs() -> Set[str]:
"""
Get the litellm supported transcription kwargs
This follows the OpenAI API Spec
"""
return set(TranscriptionCreateParams.__annotations__.keys())
@staticmethod
def _get_exclude_kwargs() -> Set[str]:
"""
Get the kwargs to exclude from the cache key
"""
return set(["metadata"])