diff --git a/litellm/caching/caching.py b/litellm/caching/caching.py index 26f94a94c2..415c49edff 100644 --- a/litellm/caching/caching.py +++ b/litellm/caching/caching.py @@ -13,26 +13,14 @@ import json import time import traceback from enum import Enum -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Union -from openai.types.audio.transcription_create_params import TranscriptionCreateParams -from openai.types.chat.completion_create_params import ( - CompletionCreateParamsNonStreaming, - CompletionCreateParamsStreaming, -) -from openai.types.completion_create_params import ( - CompletionCreateParamsNonStreaming as TextCompletionCreateParamsNonStreaming, -) -from openai.types.completion_create_params import ( - CompletionCreateParamsStreaming as TextCompletionCreateParamsStreaming, -) -from openai.types.embedding_create_params import EmbeddingCreateParams from pydantic import BaseModel import litellm from litellm._logging import verbose_logger +from litellm.litellm_core_utils.model_param_helper import ModelParamHelper from litellm.types.caching import * -from litellm.types.rerank import RerankRequest from litellm.types.utils import all_litellm_params from .base_cache import BaseCache @@ -257,7 +245,7 @@ class Cache: verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key) return preset_cache_key - combined_kwargs = self._get_relevant_args_to_use_for_cache_key() + combined_kwargs = ModelParamHelper._get_all_llm_api_params() litellm_param_kwargs = all_litellm_params for param in kwargs: if param in combined_kwargs: @@ -364,76 +352,6 @@ class Cache: if "litellm_params" in kwargs: kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key - def _get_relevant_args_to_use_for_cache_key(self) -> Set[str]: - """ - Gets the supported kwargs for each call type and combines them - """ - chat_completion_kwargs = self._get_litellm_supported_chat_completion_kwargs() - text_completion_kwargs = self._get_litellm_supported_text_completion_kwargs() - embedding_kwargs = self._get_litellm_supported_embedding_kwargs() - transcription_kwargs = self._get_litellm_supported_transcription_kwargs() - rerank_kwargs = self._get_litellm_supported_rerank_kwargs() - exclude_kwargs = self._get_kwargs_to_exclude_from_cache_key() - - combined_kwargs = chat_completion_kwargs.union( - text_completion_kwargs, - embedding_kwargs, - transcription_kwargs, - rerank_kwargs, - ) - combined_kwargs = combined_kwargs.difference(exclude_kwargs) - return combined_kwargs - - def _get_litellm_supported_chat_completion_kwargs(self) -> Set[str]: - """ - Get the litellm supported chat completion kwargs - - This follows the OpenAI API Spec - """ - all_chat_completion_kwargs = set( - CompletionCreateParamsNonStreaming.__annotations__.keys() - ).union(set(CompletionCreateParamsStreaming.__annotations__.keys())) - return all_chat_completion_kwargs - - def _get_litellm_supported_text_completion_kwargs(self) -> Set[str]: - """ - Get the litellm supported text completion kwargs - - This follows the OpenAI API Spec - """ - all_text_completion_kwargs = set( - TextCompletionCreateParamsNonStreaming.__annotations__.keys() - ).union(set(TextCompletionCreateParamsStreaming.__annotations__.keys())) - return all_text_completion_kwargs - - def _get_litellm_supported_rerank_kwargs(self) -> Set[str]: - """ - Get the litellm supported rerank kwargs - """ - return set(RerankRequest.model_fields.keys()) - - def _get_litellm_supported_embedding_kwargs(self) -> Set[str]: - """ - Get the litellm supported embedding kwargs - - This follows the OpenAI API Spec - """ - return set(EmbeddingCreateParams.__annotations__.keys()) - - def _get_litellm_supported_transcription_kwargs(self) -> Set[str]: - """ - Get the litellm supported transcription kwargs - - This follows the OpenAI API Spec - """ - return set(TranscriptionCreateParams.__annotations__.keys()) - - def _get_kwargs_to_exclude_from_cache_key(self) -> Set[str]: - """ - Get the kwargs to exclude from the cache key - """ - return set(["metadata"]) - @staticmethod def _get_hashed_cache_key(cache_key: str) -> str: """ diff --git a/litellm/litellm_core_utils/dd_tracing.py b/litellm/litellm_core_utils/dd_tracing.py index fe2c96000c..1f866a998a 100644 --- a/litellm/litellm_core_utils/dd_tracing.py +++ b/litellm/litellm_core_utils/dd_tracing.py @@ -5,9 +5,15 @@ If the ddtrace package is not installed, the tracer will be a no-op. """ from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Union from litellm.secret_managers.main import get_secret_bool +if TYPE_CHECKING: + from ddtrace.tracer import Tracer as DD_TRACER +else: + DD_TRACER = Any + class NullSpan: """A no-op span implementation.""" @@ -53,12 +59,13 @@ def _should_use_dd_tracer(): # Initialize tracer should_use_dd_tracer = _should_use_dd_tracer() -tracer = None - +tracer: Union[NullTracer, DD_TRACER] = NullTracer() +# We need to ensure tracer is never None and always has the required methods if should_use_dd_tracer: try: from ddtrace import tracer as dd_tracer + # Define the type to match what's expected by the code using this module tracer = dd_tracer except ImportError: tracer = NullTracer() diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 5a0a9c55ef..d2fd5b6ca8 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -33,6 +33,7 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.mlflow import MlflowLogger from litellm.integrations.pagerduty.pagerduty import PagerDutyAlerting from litellm.litellm_core_utils.get_litellm_params import get_litellm_params +from litellm.litellm_core_utils.model_param_helper import ModelParamHelper from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_custom_logger, redact_message_input_output_from_logging, @@ -3330,7 +3331,9 @@ def get_standard_logging_object_payload( requester_ip_address=clean_metadata.get("requester_ip_address", None), messages=kwargs.get("messages"), response=final_response_obj, - model_parameters=kwargs.get("optional_params", None), + model_parameters=ModelParamHelper.get_standard_logging_model_parameters( + kwargs.get("optional_params", None) or {} + ), hidden_params=clean_hidden_params, model_map_information=model_cost_information, error_str=error_str, diff --git a/litellm/litellm_core_utils/model_param_helper.py b/litellm/litellm_core_utils/model_param_helper.py new file mode 100644 index 0000000000..09a2c15a77 --- /dev/null +++ b/litellm/litellm_core_utils/model_param_helper.py @@ -0,0 +1,133 @@ +from typing import Set + +from openai.types.audio.transcription_create_params import TranscriptionCreateParams +from openai.types.chat.completion_create_params import ( + CompletionCreateParamsNonStreaming, + CompletionCreateParamsStreaming, +) +from openai.types.completion_create_params import ( + CompletionCreateParamsNonStreaming as TextCompletionCreateParamsNonStreaming, +) +from openai.types.completion_create_params import ( + CompletionCreateParamsStreaming as TextCompletionCreateParamsStreaming, +) +from openai.types.embedding_create_params import EmbeddingCreateParams + +from litellm.types.rerank import RerankRequest + + +class ModelParamHelper: + + @staticmethod + def get_standard_logging_model_parameters( + model_parameters: dict, + ) -> dict: + """ """ + standard_logging_model_parameters: dict = {} + supported_model_parameters = ( + ModelParamHelper._get_relevant_args_to_use_for_logging() + ) + + for key, value in model_parameters.items(): + if key in supported_model_parameters: + standard_logging_model_parameters[key] = value + return standard_logging_model_parameters + + @staticmethod + def get_exclude_params_for_model_parameters() -> Set[str]: + return set(["messages", "prompt", "input"]) + + @staticmethod + def _get_relevant_args_to_use_for_logging() -> Set[str]: + """ + Gets all relevant llm api params besides the ones with prompt content + """ + all_openai_llm_api_params = ModelParamHelper._get_all_llm_api_params() + # Exclude parameters that contain prompt content + combined_kwargs = all_openai_llm_api_params.difference( + set(ModelParamHelper.get_exclude_params_for_model_parameters()) + ) + return combined_kwargs + + @staticmethod + def _get_all_llm_api_params() -> Set[str]: + """ + Gets the supported kwargs for each call type and combines them + """ + chat_completion_kwargs = ( + ModelParamHelper._get_litellm_supported_chat_completion_kwargs() + ) + text_completion_kwargs = ( + ModelParamHelper._get_litellm_supported_text_completion_kwargs() + ) + embedding_kwargs = ModelParamHelper._get_litellm_supported_embedding_kwargs() + transcription_kwargs = ( + ModelParamHelper._get_litellm_supported_transcription_kwargs() + ) + rerank_kwargs = ModelParamHelper._get_litellm_supported_rerank_kwargs() + exclude_kwargs = ModelParamHelper._get_exclude_kwargs() + + combined_kwargs = chat_completion_kwargs.union( + text_completion_kwargs, + embedding_kwargs, + transcription_kwargs, + rerank_kwargs, + ) + combined_kwargs = combined_kwargs.difference(exclude_kwargs) + return combined_kwargs + + @staticmethod + def _get_litellm_supported_chat_completion_kwargs() -> Set[str]: + """ + Get the litellm supported chat completion kwargs + + This follows the OpenAI API Spec + """ + all_chat_completion_kwargs = set( + CompletionCreateParamsNonStreaming.__annotations__.keys() + ).union(set(CompletionCreateParamsStreaming.__annotations__.keys())) + return all_chat_completion_kwargs + + @staticmethod + def _get_litellm_supported_text_completion_kwargs() -> Set[str]: + """ + Get the litellm supported text completion kwargs + + This follows the OpenAI API Spec + """ + all_text_completion_kwargs = set( + TextCompletionCreateParamsNonStreaming.__annotations__.keys() + ).union(set(TextCompletionCreateParamsStreaming.__annotations__.keys())) + return all_text_completion_kwargs + + @staticmethod + def _get_litellm_supported_rerank_kwargs() -> Set[str]: + """ + Get the litellm supported rerank kwargs + """ + return set(RerankRequest.model_fields.keys()) + + @staticmethod + def _get_litellm_supported_embedding_kwargs() -> Set[str]: + """ + Get the litellm supported embedding kwargs + + This follows the OpenAI API Spec + """ + return set(EmbeddingCreateParams.__annotations__.keys()) + + @staticmethod + def _get_litellm_supported_transcription_kwargs() -> Set[str]: + """ + Get the litellm supported transcription kwargs + + This follows the OpenAI API Spec + """ + return set(TranscriptionCreateParams.__annotations__.keys()) + + @staticmethod + def _get_exclude_kwargs() -> Set[str]: + """ + Get the kwargs to exclude from the cache key + """ + return set(["metadata"]) diff --git a/tests/local_testing/test_unit_test_caching.py b/tests/local_testing/test_unit_test_caching.py index b1e8d4fe61..033fb774f0 100644 --- a/tests/local_testing/test_unit_test_caching.py +++ b/tests/local_testing/test_unit_test_caching.py @@ -34,13 +34,14 @@ from litellm.types.utils import ( ) from datetime import timedelta, datetime from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging +from litellm.litellm_core_utils.model_param_helper import ModelParamHelper from litellm._logging import verbose_logger import logging def test_get_kwargs_for_cache_key(): _cache = litellm.Cache() - relevant_kwargs = _cache._get_relevant_args_to_use_for_cache_key() + relevant_kwargs = ModelParamHelper._get_all_llm_api_params() print(relevant_kwargs) diff --git a/tests/logging_callback_tests/test_langsmith_unit_test.py b/tests/logging_callback_tests/test_langsmith_unit_test.py index 9f99ed4a11..2ec5f1a2e4 100644 --- a/tests/logging_callback_tests/test_langsmith_unit_test.py +++ b/tests/logging_callback_tests/test_langsmith_unit_test.py @@ -264,7 +264,6 @@ async def test_langsmith_key_based_logging(mocker): "model_parameters": { "temperature": 0.2, "max_tokens": 10, - "extra_body": {}, }, }, "outputs": {