Merge branch 'main' into litellm_contributor_prs_03_24_2025_p1

Krish Dholakia 2025-04-09 22:42:36 -07:00 committed by GitHub
commit ba03736077
727 changed files with 35116 additions and 9106 deletions

@ -57,9 +57,21 @@ import litellm._service_logger # for storing API inputs, outputs, and metadata
import litellm.litellm_core_utils
import litellm.litellm_core_utils.audio_utils.utils
import litellm.litellm_core_utils.json_validation_rule
import litellm.llms
import litellm.llms.gemini
from litellm.caching._internal_lru_cache import lru_cache_wrapper
from litellm.caching.caching import DualCache
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
from litellm.constants import (
DEFAULT_MAX_LRU_CACHE_SIZE,
DEFAULT_TRIM_RATIO,
FUNCTION_DEFINITION_TOKEN_COUNT,
INITIAL_RETRY_DELAY,
JITTER,
MAX_RETRY_DELAY,
MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
TOOL_CHOICE_OBJECT_TOKEN_COUNT,
)
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
@ -207,6 +219,7 @@ from litellm.llms.base_llm.base_utils import (
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.base_llm.files.transformation import BaseFilesConfig
from litellm.llms.base_llm.image_variations.transformation import (
BaseImageVariationConfig,
)
@ -482,7 +495,6 @@ def get_dynamic_callbacks(
def function_setup( # noqa: PLR0915
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
### NOTICES ###
from litellm import Logging as LiteLLMLogging
from litellm.litellm_core_utils.litellm_logging import set_callbacks
@ -504,9 +516,9 @@ def function_setup( # noqa: PLR0915
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
## DYNAMIC CALLBACKS ##
dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
kwargs.pop("callbacks", None)
)
dynamic_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = kwargs.pop("callbacks", None)
all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)
if len(all_callbacks) > 0:
@ -1190,9 +1202,9 @@ def client(original_function): # noqa: PLR0915
exception=e,
retry_policy=kwargs.get("retry_policy"),
)
kwargs["retry_policy"] = (
reset_retry_policy()
) # prevent infinite loops
kwargs[
"retry_policy"
] = reset_retry_policy() # prevent infinite loops
litellm.num_retries = (
None # set retries to None to prevent infinite loops
)
@ -1260,6 +1272,7 @@ def client(original_function): # noqa: PLR0915
logging_obj, kwargs = function_setup(
original_function.__name__, rules_obj, start_time, *args, **kwargs
)
kwargs["litellm_logging_obj"] = logging_obj
## LOAD CREDENTIALS
load_credentials_from_list(kwargs)
@ -1404,7 +1417,6 @@ def client(original_function): # noqa: PLR0915
if (
num_retries and not _is_litellm_router_call
): # only enter this if call is not from litellm router/proxy. router has its own logic for retrying
try:
litellm.num_retries = (
None # set retries to None to prevent infinite loops
@ -1425,7 +1437,6 @@ def client(original_function): # noqa: PLR0915
and context_window_fallback_dict
and model in context_window_fallback_dict
):
if len(args) > 0:
args[0] = context_window_fallback_dict[model] # type: ignore
else:
@ -1519,9 +1530,8 @@ def _select_tokenizer(
return _select_tokenizer_helper(model=model)
@lru_cache(maxsize=128)
@lru_cache(maxsize=DEFAULT_MAX_LRU_CACHE_SIZE)
def _select_tokenizer_helper(model: str) -> SelectTokenizerResponse:
if litellm.disable_hf_tokenizer_download is True:
return _return_openai_tokenizer(model)
@ -2235,7 +2245,8 @@ def supports_embedding_image_input(
####### HELPER FUNCTIONS ################
def _update_dictionary(existing_dict: Dict, new_dict: dict) -> dict:
for k, v in new_dict.items():
existing_dict[k] = v
if v is not None:
existing_dict[k] = v
return existing_dict
@ -2631,7 +2642,7 @@ def get_optional_params_embeddings( # noqa: PLR0915
non_default_params=non_default_params, optional_params={}, kwargs=kwargs
)
return optional_params
elif custom_llm_provider == "vertex_ai":
elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini":
supported_params = get_supported_openai_params(
model=model,
custom_llm_provider="vertex_ai",
@ -2846,6 +2857,7 @@ def get_optional_params( # noqa: PLR0915
api_version=None,
parallel_tool_calls=None,
drop_params=None,
allowed_openai_params: Optional[List[str]] = None,
reasoning_effort=None,
additional_drop_params=None,
messages: Optional[List[AllMessageValues]] = None,
@ -2931,6 +2943,7 @@ def get_optional_params( # noqa: PLR0915
"api_version": None,
"parallel_tool_calls": None,
"drop_params": None,
"allowed_openai_params": None,
"additional_drop_params": None,
"messages": None,
"reasoning_effort": None,
@ -2947,6 +2960,7 @@ def get_optional_params( # noqa: PLR0915
and k != "custom_llm_provider"
and k != "api_version"
and k != "drop_params"
and k != "allowed_openai_params"
and k != "additional_drop_params"
and k != "messages"
and k in default_params
@ -2993,16 +3007,16 @@ def get_optional_params( # noqa: PLR0915
True # so that main.py adds the function call to the prompt
)
if "tools" in non_default_params:
optional_params["functions_unsupported_model"] = (
non_default_params.pop("tools")
)
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("tools")
non_default_params.pop(
"tool_choice", None
) # causes ollama requests to hang
elif "functions" in non_default_params:
optional_params["functions_unsupported_model"] = (
non_default_params.pop("functions")
)
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("functions")
elif (
litellm.add_function_to_prompt
): # if user opts to add it to prompt instead
@ -3025,10 +3039,10 @@ def get_optional_params( # noqa: PLR0915
if "response_format" in non_default_params:
if provider_config is not None:
non_default_params["response_format"] = (
provider_config.get_json_schema_from_pydantic_object(
response_format=non_default_params["response_format"]
)
non_default_params[
"response_format"
] = provider_config.get_json_schema_from_pydantic_object(
response_format=non_default_params["response_format"]
)
else:
non_default_params["response_format"] = type_to_response_format_param(
@ -3056,6 +3070,12 @@ def get_optional_params( # noqa: PLR0915
tool_function["parameters"] = new_parameters
def _check_valid_arg(supported_params: List[str]):
"""
Check if the params passed to completion() are supported by the provider
Args:
supported_params: List[str] - supported params from the litellm config
"""
verbose_logger.info(
f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}"
)
@ -3089,7 +3109,7 @@ def get_optional_params( # noqa: PLR0915
else:
raise UnsupportedParamsError(
status_code=500,
message=f"{custom_llm_provider} does not support parameters: {unsupported_params}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
message=f"{custom_llm_provider} does not support parameters: {list(unsupported_params.keys())}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n. \n If you want to use these params dynamically send allowed_openai_params={list(unsupported_params.keys())} in your request.",
)
supported_params = get_supported_openai_params(
@ -3099,7 +3119,14 @@ def get_optional_params( # noqa: PLR0915
supported_params = get_supported_openai_params(
model=model, custom_llm_provider="openai"
)
_check_valid_arg(supported_params=supported_params or [])
supported_params = supported_params or []
allowed_openai_params = allowed_openai_params or []
supported_params.extend(allowed_openai_params)
_check_valid_arg(
supported_params=supported_params or [],
)
## raise exception if provider doesn't support passed in param
if custom_llm_provider == "anthropic":
## check if unsupported param passed in
@ -3180,7 +3207,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "replicate":
optional_params = litellm.ReplicateConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
@ -3203,7 +3229,7 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "huggingface":
optional_params = litellm.HuggingfaceConfig().map_openai_params(
optional_params = litellm.HuggingFaceChatConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
@ -3214,7 +3240,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "together_ai":
optional_params = litellm.TogetherAIConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
@ -3282,7 +3307,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "vertex_ai":
if model in litellm.vertex_mistral_models:
if "codestral" in model:
optional_params = (
@ -3356,12 +3380,10 @@ def get_optional_params( # noqa: PLR0915
if drop_params is not None and isinstance(drop_params, bool)
else False
),
messages=messages,
)
elif "anthropic" in bedrock_base_model and bedrock_route == "invoke":
if bedrock_base_model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
@ -3398,7 +3420,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "cloudflare":
optional_params = litellm.CloudflareChatConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
@ -3410,7 +3431,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "ollama":
optional_params = litellm.OllamaConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
@ -3422,7 +3442,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "ollama_chat":
optional_params = litellm.OllamaChatConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
@ -3703,6 +3722,17 @@ def get_optional_params( # noqa: PLR0915
else False
),
)
elif provider_config is not None:
optional_params = provider_config.map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
else: # assume passing in params for openai-like api
optional_params = litellm.OpenAILikeChatConfig().map_openai_params(
non_default_params=non_default_params,
@ -3745,6 +3775,26 @@ def get_optional_params( # noqa: PLR0915
if k not in default_params.keys():
optional_params[k] = passed_params[k]
print_verbose(f"Final returned optional params: {optional_params}")
optional_params = _apply_openai_param_overrides(
optional_params=optional_params,
non_default_params=non_default_params,
allowed_openai_params=allowed_openai_params,
)
return optional_params
def _apply_openai_param_overrides(
optional_params: dict, non_default_params: dict, allowed_openai_params: list
):
"""
If user passes in allowed_openai_params, apply them to optional_params
These params will get passed as is to the LLM API since the user opted in to passing them in the request
"""
if allowed_openai_params:
for param in allowed_openai_params:
if param not in optional_params:
optional_params[param] = non_default_params.pop(param, None)
return optional_params
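
For context on how this new option is meant to be used, here is a minimal usage sketch (not part of the diff). It assumes the diffed file is `litellm/utils.py`, that `completion()` forwards `allowed_openai_params` into `get_optional_params()`, and that the model/parameter names below are placeholders.

```python
# Hypothetical usage of the new allowed_openai_params option: the listed
# params bypass the UnsupportedParamsError check and are forwarded verbatim.
import litellm

response = litellm.completion(
    model="my-provider/my-model",                    # placeholder model
    messages=[{"role": "user", "content": "hello"}],
    allowed_openai_params=["parallel_tool_calls"],   # opt in to passing this through as-is
    parallel_tool_calls=False,                       # would otherwise be dropped or rejected
)
```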
@ -4008,9 +4058,9 @@ def _count_characters(text: str) -> int:
def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str:
_choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = (
response_obj.choices
)
_choices: Union[
List[Union[Choices, StreamingChoices]], List[StreamingChoices]
] = response_obj.choices
response_str = ""
for choice in _choices:
@ -4408,7 +4458,6 @@ def _get_model_info_helper( # noqa: PLR0915
):
_model_info = None
if _model_info is None and model in litellm.model_cost:
key = model
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(
@ -4419,7 +4468,6 @@ def _get_model_info_helper( # noqa: PLR0915
_model_info is None
and combined_stripped_model_name in litellm.model_cost
):
key = combined_stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(
@ -4427,7 +4475,6 @@ def _get_model_info_helper( # noqa: PLR0915
):
_model_info = None
if _model_info is None and stripped_model_name in litellm.model_cost:
key = stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(
@ -4435,7 +4482,6 @@ def _get_model_info_helper( # noqa: PLR0915
):
_model_info = None
if _model_info is None and split_model in litellm.model_cost:
key = split_model
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(
@ -4490,6 +4536,9 @@ def _get_model_info_helper( # noqa: PLR0915
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
input_cost_per_token_above_200k_tokens=_model_info.get(
"input_cost_per_token_above_200k_tokens", None
),
input_cost_per_query=_model_info.get("input_cost_per_query", None),
input_cost_per_second=_model_info.get("input_cost_per_second", None),
input_cost_per_audio_token=_model_info.get(
@ -4514,6 +4563,9 @@ def _get_model_info_helper( # noqa: PLR0915
output_cost_per_character_above_128k_tokens=_model_info.get(
"output_cost_per_character_above_128k_tokens", None
),
output_cost_per_token_above_200k_tokens=_model_info.get(
"output_cost_per_token_above_200k_tokens", None
),
output_cost_per_second=_model_info.get("output_cost_per_second", None),
output_cost_per_image=_model_info.get("output_cost_per_image", None),
output_vector_size=_model_info.get("output_vector_size", None),
@ -5324,15 +5376,15 @@ def _calculate_retry_after(
if retry_after is not None and 0 < retry_after <= 60:
return retry_after
initial_retry_delay = 0.5
max_retry_delay = 8.0
initial_retry_delay = INITIAL_RETRY_DELAY
max_retry_delay = MAX_RETRY_DELAY
nb_retries = max_retries - remaining_retries
# Apply exponential backoff, but not more than the max.
sleep_seconds = min(initial_retry_delay * pow(2.0, nb_retries), max_retry_delay)
# Apply some jitter, plus-or-minus half a second.
jitter = 1 - 0.25 * random.random()
jitter = JITTER * random.random()
timeout = sleep_seconds * jitter
return timeout if timeout >= min_timeout else min_timeout
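
To make the backoff change easier to follow, here is a standalone sketch of the computation under stated assumptions: `INITIAL_RETRY_DELAY` and `MAX_RETRY_DELAY` are assumed to keep the previously hard-coded 0.5s/8.0s values, and the value of `JITTER` is not shown in this diff, so the number below is a placeholder.

```python
import random

INITIAL_RETRY_DELAY = 0.5  # assumed: previously the literal 0.5
MAX_RETRY_DELAY = 8.0      # assumed: previously the literal 8.0
JITTER = 0.75              # placeholder: real value lives in litellm.constants

def backoff_seconds(remaining_retries: int, max_retries: int, min_timeout: float = 0.0) -> float:
    """Exponential backoff with jitter, mirroring the updated _calculate_retry_after math."""
    nb_retries = max_retries - remaining_retries
    # Exponential backoff, capped at the maximum delay.
    sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)
    # Scale by a random jitter factor, as in the new JITTER * random.random() form.
    timeout = sleep_seconds * (JITTER * random.random())
    return timeout if timeout >= min_timeout else min_timeout
```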
@ -5658,7 +5710,7 @@ def shorten_message_to_fit_limit(message, tokens_needed, model: Optional[str]):
def trim_messages(
messages,
model: Optional[str] = None,
trim_ratio: float = 0.75,
trim_ratio: float = DEFAULT_TRIM_RATIO,
return_response_tokens: bool = False,
max_tokens=None,
):
@ -5757,13 +5809,15 @@ def trim_messages(
return messages
def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
def get_valid_models(
check_provider_endpoint: bool = False, custom_llm_provider: Optional[str] = None
) -> List[str]:
"""
Returns a list of valid LLMs based on the set environment variables
Args:
check_provider_endpoint: If True, will check the provider's endpoint for valid models.
custom_llm_provider: If provided, will only check the provider's endpoint for valid models.
Returns:
A list of valid LLMs
"""
@ -5775,6 +5829,9 @@ def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
valid_models = []
for provider in litellm.provider_list:
if custom_llm_provider and provider != custom_llm_provider:
continue
# edge case litellm has together_ai as a provider, it should be togetherai
env_provider_1 = provider.replace("_", "")
env_provider_2 = provider
@ -5796,10 +5853,17 @@ def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
provider=LlmProviders(provider),
)
if custom_llm_provider and provider != custom_llm_provider:
continue
if provider == "azure":
valid_models.append("Azure-LLM")
elif provider_config is not None and check_provider_endpoint:
valid_models.extend(provider_config.get_models())
try:
models = provider_config.get_models()
valid_models.extend(models)
except Exception as e:
verbose_logger.debug(f"Error getting valid models: {e}")
else:
models_for_provider = litellm.models_by_provider.get(provider, [])
valid_models.extend(models_for_provider)
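
A short usage sketch of the new provider filter (assuming the diffed file is `litellm/utils.py`; the provider name is just an example): only the named provider's endpoint is queried for its model list.

```python
from litellm.utils import get_valid_models

# With check_provider_endpoint=True and a provider filter, only that
# provider's model-listing endpoint is hit; lookup errors are logged and skipped.
gemini_models = get_valid_models(
    check_provider_endpoint=True,
    custom_llm_provider="gemini",
)
print(gemini_models)
```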
@ -5917,9 +5981,10 @@ class ModelResponseIterator:
class ModelResponseListIterator:
def __init__(self, model_responses):
def __init__(self, model_responses, delay: Optional[float] = None):
self.model_responses = model_responses
self.index = 0
self.delay = delay
# Sync iterator
def __iter__(self):
@ -5930,6 +5995,8 @@ class ModelResponseListIterator:
raise StopIteration
model_response = self.model_responses[self.index]
self.index += 1
if self.delay:
time.sleep(self.delay)
return model_response
# Async iterator
@ -5941,6 +6008,8 @@ class ModelResponseListIterator:
raise StopAsyncIteration
model_response = self.model_responses[self.index]
self.index += 1
if self.delay:
await asyncio.sleep(self.delay)
return model_response
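
A test-style sketch of the new `delay` parameter (the chunk objects here are plain dicts standing in for real model responses; assumes the class stays importable from `litellm.utils`):

```python
from litellm.utils import ModelResponseListIterator

fake_chunks = [{"id": "chunk-1"}, {"id": "chunk-2"}, {"id": "chunk-3"}]

# delay=0.05 makes the iterator sleep ~50ms before yielding each item,
# which is useful for simulating paced streaming in tests.
for chunk in ModelResponseListIterator(model_responses=fake_chunks, delay=0.05):
    print(chunk["id"])
```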
@ -6073,6 +6142,8 @@ def validate_and_fix_openai_messages(messages: List):
for message in messages:
if not message.get("role"):
message["role"] = "assistant"
if message.get("tool_calls"):
message["tool_calls"] = jsonify_tools(tools=message["tool_calls"])
return validate_chat_completion_messages(messages=messages)
@ -6161,7 +6232,7 @@ class ProviderConfigManager:
@staticmethod
def get_provider_chat_config( # noqa: PLR0915
model: str, provider: LlmProviders
) -> BaseConfig:
) -> Optional[BaseConfig]:
"""
Returns the provider config for a given provider.
"""
@ -6192,9 +6263,22 @@ class ProviderConfigManager:
return litellm.AnthropicConfig()
elif litellm.LlmProviders.ANTHROPIC_TEXT == provider:
return litellm.AnthropicTextConfig()
elif litellm.LlmProviders.VERTEX_AI_BETA == provider:
return litellm.VertexGeminiConfig()
elif litellm.LlmProviders.VERTEX_AI == provider:
if "claude" in model:
if "gemini" in model:
return litellm.VertexGeminiConfig()
elif "claude" in model:
return litellm.VertexAIAnthropicConfig()
elif model in litellm.vertex_mistral_models:
if "codestral" in model:
return litellm.CodestralTextCompletionConfig()
else:
return litellm.MistralConfig()
elif model in litellm.vertex_ai_ai21_models:
return litellm.VertexAIAi21Config()
else: # use generic openai-like param mapping
return litellm.VertexAILlama3Config()
elif litellm.LlmProviders.CLOUDFLARE == provider:
return litellm.CloudflareChatConfig()
elif litellm.LlmProviders.SAGEMAKER_CHAT == provider:
@ -6217,7 +6301,6 @@ class ProviderConfigManager:
litellm.LlmProviders.CUSTOM == provider
or litellm.LlmProviders.CUSTOM_OPENAI == provider
or litellm.LlmProviders.OPENAI_LIKE == provider
or litellm.LlmProviders.LITELLM_PROXY == provider
):
return litellm.OpenAILikeChatConfig()
elif litellm.LlmProviders.AIOHTTP_OPENAI == provider:
@ -6231,7 +6314,7 @@ class ProviderConfigManager:
elif litellm.LlmProviders.REPLICATE == provider:
return litellm.ReplicateConfig()
elif litellm.LlmProviders.HUGGINGFACE == provider:
return litellm.HuggingfaceConfig()
return litellm.HuggingFaceChatConfig()
elif litellm.LlmProviders.TOGETHER_AI == provider:
return litellm.TogetherAIConfig()
elif litellm.LlmProviders.OPENROUTER == provider:
@ -6324,9 +6407,15 @@ class ProviderConfigManager:
return litellm.AmazonMistralConfig()
elif bedrock_invoke_provider == "deepseek_r1": # deepseek models on bedrock
return litellm.AmazonDeepSeekR1Config()
elif bedrock_invoke_provider == "nova":
return litellm.AmazonInvokeNovaConfig()
else:
return litellm.AmazonInvokeConfig()
return litellm.OpenAIGPTConfig()
elif litellm.LlmProviders.LITELLM_PROXY == provider:
return litellm.LiteLLMProxyChatConfig()
elif litellm.LlmProviders.OPENAI == provider:
return litellm.OpenAIGPTConfig()
return None
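
Since `get_provider_chat_config` can now return `None` (its signature changed to `Optional[BaseConfig]` above), callers are expected to guard for that. A minimal sketch, assuming the diffed file is `litellm/utils.py` and that `litellm.LlmProviders` is exposed the way it is used elsewhere in this file:

```python
import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    model="gemini-1.5-pro",                # example model; hits the new Vertex gemini branch
    provider=litellm.LlmProviders.VERTEX_AI,
)
if config is None:
    raise ValueError("no chat config registered for this model/provider")
print(type(config).__name__)               # expected: VertexGeminiConfig, per the routing above
```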
@staticmethod
def get_provider_embedding_config(
@ -6379,6 +6468,11 @@ class ProviderConfigManager:
return litellm.FireworksAIAudioTranscriptionConfig()
elif litellm.LlmProviders.DEEPGRAM == provider:
return litellm.DeepgramAudioTranscriptionConfig()
elif litellm.LlmProviders.OPENAI == provider:
if "gpt-4o" in model:
return litellm.OpenAIGPTAudioTranscriptionConfig()
else:
return litellm.OpenAIWhisperAudioTranscriptionConfig()
return None
@staticmethod
@ -6410,10 +6504,16 @@ class ProviderConfigManager:
return litellm.FireworksAIConfig()
elif LlmProviders.OPENAI == provider:
return litellm.OpenAIGPTConfig()
elif LlmProviders.GEMINI == provider:
return litellm.GeminiModelInfo()
elif LlmProviders.LITELLM_PROXY == provider:
return litellm.LiteLLMProxyChatConfig()
elif LlmProviders.TOPAZ == provider:
return litellm.TopazModelInfo()
elif LlmProviders.ANTHROPIC == provider:
return litellm.AnthropicModelInfo()
elif LlmProviders.XAI == provider:
return litellm.XAIModelInfo()
return None
@ -6428,6 +6528,23 @@ class ProviderConfigManager:
return litellm.TopazImageVariationConfig()
return None
@staticmethod
def get_provider_files_config(
model: str,
provider: LlmProviders,
) -> Optional[BaseFilesConfig]:
if LlmProviders.GEMINI == provider:
from litellm.llms.gemini.files.transformation import (
GoogleAIStudioFilesHandler, # experimental approach, to reduce bloat on __init__.py
)
return GoogleAIStudioFilesHandler()
elif LlmProviders.VERTEX_AI == provider:
from litellm.llms.vertex_ai.files.transformation import VertexAIFilesConfig
return VertexAIFilesConfig()
return None
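
A similar sketch for the new files-config lookup (the model string is an example; the handler classes are imported lazily inside the method, as the comment in the diff notes):

```python
import litellm
from litellm.utils import ProviderConfigManager

files_config = ProviderConfigManager.get_provider_files_config(
    model="gemini-1.5-flash",              # example model
    provider=litellm.LlmProviders.GEMINI,
)
print(type(files_config).__name__)          # expected: GoogleAIStudioFilesHandler
```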
def get_end_user_id_for_cost_tracking(
litellm_params: dict,
@ -6492,7 +6609,7 @@ def is_prompt_caching_valid_prompt(
model=model,
use_default_image_token_count=True,
)
return token_count >= 1024
return token_count >= MINIMUM_PROMPT_CACHE_TOKEN_COUNT
except Exception as e:
verbose_logger.error(f"Error in is_prompt_caching_valid_prompt: {e}")
return False
@ -6644,3 +6761,20 @@ def return_raw_request(endpoint: CallTypes, kwargs: dict) -> RawRequestTypedDict
return RawRequestTypedDict(
error=received_exception,
)
def jsonify_tools(tools: List[Any]) -> List[Dict]:
"""
Fixes https://github.com/BerriAI/litellm/issues/9321
Where user passes in a pydantic base model
"""
new_tools: List[Dict] = []
for tool in tools:
if isinstance(tool, BaseModel):
tool = tool.model_dump(exclude_none=True)
elif isinstance(tool, dict):
tool = tool.copy()
if isinstance(tool, dict):
new_tools.append(tool)
return new_tools
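
A minimal sketch of the fix for issue #9321: a pydantic tool call is dumped to a plain dict (dropping `None` fields) before message validation. The `FakeToolCall` model below is purely illustrative, and the import assumes the diffed file is `litellm/utils.py`.

```python
from typing import Optional

from pydantic import BaseModel

from litellm.utils import jsonify_tools

class FakeToolCall(BaseModel):  # illustrative stand-in for a pydantic tool-call object
    id: str
    type: str = "function"
    function: Optional[dict] = None

tool_calls = [FakeToolCall(id="call_1", function={"name": "get_weather", "arguments": "{}"})]
print(jsonify_tools(tools=tool_calls))
# -> [{'id': 'call_1', 'type': 'function', 'function': {'name': 'get_weather', 'arguments': '{}'}}]
```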