Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)

Merge branch 'main' into litellm_contributor_prs_03_24_2025_p1
Commit ba03736077: 727 changed files with 35116 additions and 9106 deletions

litellm/utils.py (244 changed lines)
@@ -57,9 +57,21 @@ import litellm._service_logger # for storing API inputs, outputs, and metadata
import litellm.litellm_core_utils
import litellm.litellm_core_utils.audio_utils.utils
import litellm.litellm_core_utils.json_validation_rule
import litellm.llms
import litellm.llms.gemini
from litellm.caching._internal_lru_cache import lru_cache_wrapper
from litellm.caching.caching import DualCache
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
from litellm.constants import (
DEFAULT_MAX_LRU_CACHE_SIZE,
DEFAULT_TRIM_RATIO,
FUNCTION_DEFINITION_TOKEN_COUNT,
INITIAL_RETRY_DELAY,
JITTER,
MAX_RETRY_DELAY,
MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
TOOL_CHOICE_OBJECT_TOKEN_COUNT,
)
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (

@@ -207,6 +219,7 @@ from litellm.llms.base_llm.base_utils import (
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.base_llm.files.transformation import BaseFilesConfig
from litellm.llms.base_llm.image_variations.transformation import (
BaseImageVariationConfig,
)
@@ -482,7 +495,6 @@ def get_dynamic_callbacks(
def function_setup( # noqa: PLR0915
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.

### NOTICES ###
from litellm import Logging as LiteLLMLogging
from litellm.litellm_core_utils.litellm_logging import set_callbacks

@@ -504,9 +516,9 @@ def function_setup( # noqa: PLR0915
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None

## DYNAMIC CALLBACKS ##
dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
kwargs.pop("callbacks", None)
)
dynamic_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = kwargs.pop("callbacks", None)
all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)

if len(all_callbacks) > 0:

@@ -1190,9 +1202,9 @@ def client(original_function): # noqa: PLR0915
exception=e,
retry_policy=kwargs.get("retry_policy"),
)
kwargs["retry_policy"] = (
reset_retry_policy()
) # prevent infinite loops
kwargs[
"retry_policy"
] = reset_retry_policy() # prevent infinite loops
litellm.num_retries = (
None # set retries to None to prevent infinite loops
)

@@ -1260,6 +1272,7 @@ def client(original_function): # noqa: PLR0915
logging_obj, kwargs = function_setup(
original_function.__name__, rules_obj, start_time, *args, **kwargs
)

kwargs["litellm_logging_obj"] = logging_obj
## LOAD CREDENTIALS
load_credentials_from_list(kwargs)

@@ -1404,7 +1417,6 @@ def client(original_function): # noqa: PLR0915
if (
num_retries and not _is_litellm_router_call
): # only enter this if call is not from litellm router/proxy. router has it's own logic for retrying

try:
litellm.num_retries = (
None # set retries to None to prevent infinite loops

@@ -1425,7 +1437,6 @@ def client(original_function): # noqa: PLR0915
and context_window_fallback_dict
and model in context_window_fallback_dict
):

if len(args) > 0:
args[0] = context_window_fallback_dict[model] # type: ignore
else:

@@ -1519,9 +1530,8 @@ def _select_tokenizer(
return _select_tokenizer_helper(model=model)


@lru_cache(maxsize=128)
@lru_cache(maxsize=DEFAULT_MAX_LRU_CACHE_SIZE)
def _select_tokenizer_helper(model: str) -> SelectTokenizerResponse:

if litellm.disable_hf_tokenizer_download is True:
return _return_openai_tokenizer(model)

@@ -2235,7 +2245,8 @@ def supports_embedding_image_input(
####### HELPER FUNCTIONS ################
def _update_dictionary(existing_dict: Dict, new_dict: dict) -> dict:
for k, v in new_dict.items():
existing_dict[k] = v
if v is not None:
existing_dict[k] = v

return existing_dict
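
For illustration, the behavioral change in _update_dictionary above comes down to the added `if v is not None` guard: None values in the incoming dict no longer overwrite existing entries. The dictionaries below are made-up values, not taken from the diff.

existing = {"max_tokens": 100, "temperature": 0.7}
incoming = {"max_tokens": 200, "temperature": None}

for k, v in incoming.items():
    if v is not None:  # the guard added in this hunk
        existing[k] = v

assert existing == {"max_tokens": 200, "temperature": 0.7}
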
@@ -2631,7 +2642,7 @@ def get_optional_params_embeddings( # noqa: PLR0915
non_default_params=non_default_params, optional_params={}, kwargs=kwargs
)
return optional_params
elif custom_llm_provider == "vertex_ai":
elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini":
supported_params = get_supported_openai_params(
model=model,
custom_llm_provider="vertex_ai",

@@ -2846,6 +2857,7 @@ def get_optional_params( # noqa: PLR0915
api_version=None,
parallel_tool_calls=None,
drop_params=None,
allowed_openai_params: Optional[List[str]] = None,
reasoning_effort=None,
additional_drop_params=None,
messages: Optional[List[AllMessageValues]] = None,

@@ -2931,6 +2943,7 @@ def get_optional_params( # noqa: PLR0915
"api_version": None,
"parallel_tool_calls": None,
"drop_params": None,
"allowed_openai_params": None,
"additional_drop_params": None,
"messages": None,
"reasoning_effort": None,

@@ -2947,6 +2960,7 @@ def get_optional_params( # noqa: PLR0915
and k != "custom_llm_provider"
and k != "api_version"
and k != "drop_params"
and k != "allowed_openai_params"
and k != "additional_drop_params"
and k != "messages"
and k in default_params

@@ -2993,16 +3007,16 @@ def get_optional_params( # noqa: PLR0915
True # so that main.py adds the function call to the prompt
)
if "tools" in non_default_params:
optional_params["functions_unsupported_model"] = (
non_default_params.pop("tools")
)
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("tools")
non_default_params.pop(
"tool_choice", None
) # causes ollama requests to hang
elif "functions" in non_default_params:
optional_params["functions_unsupported_model"] = (
non_default_params.pop("functions")
)
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("functions")
elif (
litellm.add_function_to_prompt
): # if user opts to add it to prompt instead

@@ -3025,10 +3039,10 @@ def get_optional_params( # noqa: PLR0915

if "response_format" in non_default_params:
if provider_config is not None:
non_default_params["response_format"] = (
provider_config.get_json_schema_from_pydantic_object(
response_format=non_default_params["response_format"]
)
non_default_params[
"response_format"
] = provider_config.get_json_schema_from_pydantic_object(
response_format=non_default_params["response_format"]
)
else:
non_default_params["response_format"] = type_to_response_format_param(

@@ -3056,6 +3070,12 @@ def get_optional_params( # noqa: PLR0915
tool_function["parameters"] = new_parameters

def _check_valid_arg(supported_params: List[str]):
"""
Check if the params passed to completion() are supported by the provider

Args:
supported_params: List[str] - supported params from the litellm config
"""
verbose_logger.info(
f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}"
)

@@ -3089,7 +3109,7 @@ def get_optional_params( # noqa: PLR0915
else:
raise UnsupportedParamsError(
status_code=500,
message=f"{custom_llm_provider} does not support parameters: {unsupported_params}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
message=f"{custom_llm_provider} does not support parameters: {list(unsupported_params.keys())}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n. \n If you want to use these params dynamically send allowed_openai_params={list(unsupported_params.keys())} in your request.",
)

supported_params = get_supported_openai_params(

@@ -3099,7 +3119,14 @@ def get_optional_params( # noqa: PLR0915
supported_params = get_supported_openai_params(
model=model, custom_llm_provider="openai"
)
_check_valid_arg(supported_params=supported_params or [])

supported_params = supported_params or []
allowed_openai_params = allowed_openai_params or []
supported_params.extend(allowed_openai_params)

_check_valid_arg(
supported_params=supported_params or [],
)
## raise exception if provider doesn't support passed in param
if custom_llm_provider == "anthropic":
## check if unsupported param passed in
@@ -3180,7 +3207,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "replicate":

optional_params = litellm.ReplicateConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,

@@ -3203,7 +3229,7 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "huggingface":
optional_params = litellm.HuggingfaceConfig().map_openai_params(
optional_params = litellm.HuggingFaceChatConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,

@@ -3214,7 +3240,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "together_ai":

optional_params = litellm.TogetherAIConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,

@@ -3282,7 +3307,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "vertex_ai":

if model in litellm.vertex_mistral_models:
if "codestral" in model:
optional_params = (

@@ -3356,12 +3380,10 @@ def get_optional_params( # noqa: PLR0915
if drop_params is not None and isinstance(drop_params, bool)
else False
),
messages=messages,
)

elif "anthropic" in bedrock_base_model and bedrock_route == "invoke":
if bedrock_base_model.startswith("anthropic.claude-3"):

optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,

@@ -3398,7 +3420,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "cloudflare":

optional_params = litellm.CloudflareChatConfig().map_openai_params(
model=model,
non_default_params=non_default_params,

@@ -3410,7 +3431,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "ollama":

optional_params = litellm.OllamaConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,

@@ -3422,7 +3442,6 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "ollama_chat":

optional_params = litellm.OllamaChatConfig().map_openai_params(
model=model,
non_default_params=non_default_params,

@@ -3703,6 +3722,17 @@ def get_optional_params( # noqa: PLR0915
else False
),
)
elif provider_config is not None:
optional_params = provider_config.map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
else: # assume passing in params for openai-like api
optional_params = litellm.OpenAILikeChatConfig().map_openai_params(
non_default_params=non_default_params,

@@ -3745,6 +3775,26 @@ def get_optional_params( # noqa: PLR0915
if k not in default_params.keys():
optional_params[k] = passed_params[k]
print_verbose(f"Final returned optional params: {optional_params}")
optional_params = _apply_openai_param_overrides(
optional_params=optional_params,
non_default_params=non_default_params,
allowed_openai_params=allowed_openai_params,
)
return optional_params


def _apply_openai_param_overrides(
optional_params: dict, non_default_params: dict, allowed_openai_params: list
):
"""
If user passes in allowed_openai_params, apply them to optional_params

These params will get passed as is to the LLM API since the user opted in to passing them in the request
"""
if allowed_openai_params:
for param in allowed_openai_params:
if param not in optional_params:
optional_params[param] = non_default_params.pop(param, None)
return optional_params
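
A minimal sketch of what the new _apply_openai_param_overrides helper does, inlined with hypothetical parameter values: any param named in allowed_openai_params that is not already in optional_params is passed through as-is, or set to None if it was never supplied.

optional_params = {"temperature": 0.2}
non_default_params = {"temperature": 0.2, "seed": 42}
allowed_openai_params = ["seed", "logprobs"]

# Same loop as the helper in this hunk, shown standalone.
for param in allowed_openai_params:
    if param not in optional_params:
        optional_params[param] = non_default_params.pop(param, None)

assert optional_params == {"temperature": 0.2, "seed": 42, "logprobs": None}
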
@@ -4008,9 +4058,9 @@ def _count_characters(text: str) -> int:


def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str:
_choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = (
response_obj.choices
)
_choices: Union[
List[Union[Choices, StreamingChoices]], List[StreamingChoices]
] = response_obj.choices

response_str = ""
for choice in _choices:

@@ -4408,7 +4458,6 @@ def _get_model_info_helper( # noqa: PLR0915
):
_model_info = None
if _model_info is None and model in litellm.model_cost:

key = model
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(

@@ -4419,7 +4468,6 @@ def _get_model_info_helper( # noqa: PLR0915
_model_info is None
and combined_stripped_model_name in litellm.model_cost
):

key = combined_stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(

@@ -4427,7 +4475,6 @@ def _get_model_info_helper( # noqa: PLR0915
):
_model_info = None
if _model_info is None and stripped_model_name in litellm.model_cost:

key = stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(

@@ -4435,7 +4482,6 @@ def _get_model_info_helper( # noqa: PLR0915
):
_model_info = None
if _model_info is None and split_model in litellm.model_cost:

key = split_model
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(

@@ -4490,6 +4536,9 @@ def _get_model_info_helper( # noqa: PLR0915
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
input_cost_per_token_above_200k_tokens=_model_info.get(
"input_cost_per_token_above_200k_tokens", None
),
input_cost_per_query=_model_info.get("input_cost_per_query", None),
input_cost_per_second=_model_info.get("input_cost_per_second", None),
input_cost_per_audio_token=_model_info.get(

@@ -4514,6 +4563,9 @@ def _get_model_info_helper( # noqa: PLR0915
output_cost_per_character_above_128k_tokens=_model_info.get(
"output_cost_per_character_above_128k_tokens", None
),
output_cost_per_token_above_200k_tokens=_model_info.get(
"output_cost_per_token_above_200k_tokens", None
),
output_cost_per_second=_model_info.get("output_cost_per_second", None),
output_cost_per_image=_model_info.get("output_cost_per_image", None),
output_vector_size=_model_info.get("output_vector_size", None),

@@ -5324,15 +5376,15 @@ def _calculate_retry_after(
if retry_after is not None and 0 < retry_after <= 60:
return retry_after

initial_retry_delay = 0.5
max_retry_delay = 8.0
initial_retry_delay = INITIAL_RETRY_DELAY
max_retry_delay = MAX_RETRY_DELAY
nb_retries = max_retries - remaining_retries

# Apply exponential backoff, but not more than the max.
sleep_seconds = min(initial_retry_delay * pow(2.0, nb_retries), max_retry_delay)

# Apply some jitter, plus-or-minus half a second.
jitter = 1 - 0.25 * random.random()
jitter = JITTER * random.random()
timeout = sleep_seconds * jitter
return timeout if timeout >= min_timeout else min_timeout
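
A standalone sketch of the updated backoff calculation. The constant values below are assumptions for illustration only; the real values come from litellm.constants (INITIAL_RETRY_DELAY, MAX_RETRY_DELAY, JITTER).

import random

INITIAL_RETRY_DELAY = 0.5  # assumed value
MAX_RETRY_DELAY = 8.0  # assumed value
JITTER = 0.75  # assumed value


def sketch_retry_after(max_retries: int, remaining_retries: int, min_timeout: float = 0.0) -> float:
    nb_retries = max_retries - remaining_retries
    # Exponential backoff, capped at the maximum delay.
    sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)
    # Multiplicative jitter, matching the updated hunk.
    timeout = sleep_seconds * (JITTER * random.random())
    return timeout if timeout >= min_timeout else min_timeout


# Third attempt out of five: base delay is min(0.5 * 2**2, 8.0) = 2.0s before jitter.
print(sketch_retry_after(max_retries=5, remaining_retries=3))
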
@@ -5658,7 +5710,7 @@ def shorten_message_to_fit_limit(message, tokens_needed, model: Optional[str]):
def trim_messages(
messages,
model: Optional[str] = None,
trim_ratio: float = 0.75,
trim_ratio: float = DEFAULT_TRIM_RATIO,
return_response_tokens: bool = False,
max_tokens=None,
):

@@ -5757,13 +5809,15 @@ def trim_messages(
return messages


def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
def get_valid_models(
check_provider_endpoint: bool = False, custom_llm_provider: Optional[str] = None
) -> List[str]:
"""
Returns a list of valid LLMs based on the set environment variables

Args:
check_provider_endpoint: If True, will check the provider's endpoint for valid models.

custom_llm_provider: If provided, will only check the provider's endpoint for valid models.
Returns:
A list of valid LLMs
"""

@@ -5775,6 +5829,9 @@ def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
valid_models = []

for provider in litellm.provider_list:
if custom_llm_provider and provider != custom_llm_provider:
continue

# edge case litellm has together_ai as a provider, it should be togetherai
env_provider_1 = provider.replace("_", "")
env_provider_2 = provider

@@ -5796,10 +5853,17 @@ def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
provider=LlmProviders(provider),
)

if custom_llm_provider and provider != custom_llm_provider:
continue

if provider == "azure":
valid_models.append("Azure-LLM")
elif provider_config is not None and check_provider_endpoint:
valid_models.extend(provider_config.get_models())
try:
models = provider_config.get_models()
valid_models.extend(models)
except Exception as e:
verbose_logger.debug(f"Error getting valid models: {e}")
else:
models_for_provider = litellm.models_by_provider.get(provider, [])
valid_models.extend(models_for_provider)
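
A possible call into the extended get_valid_models, assuming the relevant provider API key (e.g. GEMINI_API_KEY) is set in the environment: check_provider_endpoint=True asks the provider's endpoint for its model list, and custom_llm_provider limits the scan to that one provider.

from litellm.utils import get_valid_models

# Only scan the "gemini" provider and query its endpoint for current models.
models = get_valid_models(check_provider_endpoint=True, custom_llm_provider="gemini")
print(models)
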
@@ -5917,9 +5981,10 @@ class ModelResponseIterator:


class ModelResponseListIterator:
def __init__(self, model_responses):
def __init__(self, model_responses, delay: Optional[float] = None):
self.model_responses = model_responses
self.index = 0
self.delay = delay

# Sync iterator
def __iter__(self):

@@ -5930,6 +5995,8 @@ class ModelResponseListIterator:
raise StopIteration
model_response = self.model_responses[self.index]
self.index += 1
if self.delay:
time.sleep(self.delay)
return model_response

# Async iterator

@@ -5941,6 +6008,8 @@ class ModelResponseListIterator:
raise StopAsyncIteration
model_response = self.model_responses[self.index]
self.index += 1
if self.delay:
await asyncio.sleep(self.delay)
return model_response
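
A rough sketch of the new delay argument on ModelResponseListIterator, using made-up chunk objects in place of real streaming responses; each iteration now sleeps for the configured delay before returning a chunk, which is useful for simulating a slow stream in tests.

import asyncio

from litellm.utils import ModelResponseListIterator

chunks = [{"id": "chunk-1"}, {"id": "chunk-2"}, {"id": "chunk-3"}]  # placeholder chunks
iterator = ModelResponseListIterator(model_responses=chunks, delay=0.05)


async def consume():
    async for chunk in iterator:  # ~50ms pause before each chunk
        print(chunk)


asyncio.run(consume())
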
@@ -6073,6 +6142,8 @@ def validate_and_fix_openai_messages(messages: List):
for message in messages:
if not message.get("role"):
message["role"] = "assistant"
if message.get("tool_calls"):
message["tool_calls"] = jsonify_tools(tools=message["tool_calls"])
return validate_chat_completion_messages(messages=messages)

@@ -6161,7 +6232,7 @@ class ProviderConfigManager:
@staticmethod
def get_provider_chat_config( # noqa: PLR0915
model: str, provider: LlmProviders
) -> BaseConfig:
) -> Optional[BaseConfig]:
"""
Returns the provider config for a given provider.
"""

@@ -6192,9 +6263,22 @@ class ProviderConfigManager:
return litellm.AnthropicConfig()
elif litellm.LlmProviders.ANTHROPIC_TEXT == provider:
return litellm.AnthropicTextConfig()
elif litellm.LlmProviders.VERTEX_AI_BETA == provider:
return litellm.VertexGeminiConfig()
elif litellm.LlmProviders.VERTEX_AI == provider:
if "claude" in model:
if "gemini" in model:
return litellm.VertexGeminiConfig()
elif "claude" in model:
return litellm.VertexAIAnthropicConfig()
elif model in litellm.vertex_mistral_models:
if "codestral" in model:
return litellm.CodestralTextCompletionConfig()
else:
return litellm.MistralConfig()
elif model in litellm.vertex_ai_ai21_models:
return litellm.VertexAIAi21Config()
else: # use generic openai-like param mapping
return litellm.VertexAILlama3Config()
elif litellm.LlmProviders.CLOUDFLARE == provider:
return litellm.CloudflareChatConfig()
elif litellm.LlmProviders.SAGEMAKER_CHAT == provider:
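
Since get_provider_chat_config now returns Optional[BaseConfig], callers should check for None before using the config. A hedged sketch of that pattern, with an illustrative model name:

import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    model="claude-3-sonnet-20240229",  # illustrative model name
    provider=litellm.LlmProviders.ANTHROPIC,
)
if config is not None:
    mapped = config.map_openai_params(
        non_default_params={"temperature": 0.2, "max_tokens": 256},
        optional_params={},
        model="claude-3-sonnet-20240229",
        drop_params=False,
    )
    print(mapped)
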
@@ -6217,7 +6301,6 @@ class ProviderConfigManager:
litellm.LlmProviders.CUSTOM == provider
or litellm.LlmProviders.CUSTOM_OPENAI == provider
or litellm.LlmProviders.OPENAI_LIKE == provider
or litellm.LlmProviders.LITELLM_PROXY == provider
):
return litellm.OpenAILikeChatConfig()
elif litellm.LlmProviders.AIOHTTP_OPENAI == provider:

@@ -6231,7 +6314,7 @@ class ProviderConfigManager:
elif litellm.LlmProviders.REPLICATE == provider:
return litellm.ReplicateConfig()
elif litellm.LlmProviders.HUGGINGFACE == provider:
return litellm.HuggingfaceConfig()
return litellm.HuggingFaceChatConfig()
elif litellm.LlmProviders.TOGETHER_AI == provider:
return litellm.TogetherAIConfig()
elif litellm.LlmProviders.OPENROUTER == provider:

@@ -6324,9 +6407,15 @@ class ProviderConfigManager:
return litellm.AmazonMistralConfig()
elif bedrock_invoke_provider == "deepseek_r1": # deepseek models on bedrock
return litellm.AmazonDeepSeekR1Config()
elif bedrock_invoke_provider == "nova":
return litellm.AmazonInvokeNovaConfig()
else:
return litellm.AmazonInvokeConfig()
return litellm.OpenAIGPTConfig()
elif litellm.LlmProviders.LITELLM_PROXY == provider:
return litellm.LiteLLMProxyChatConfig()
elif litellm.LlmProviders.OPENAI == provider:
return litellm.OpenAIGPTConfig()
return None

@staticmethod
def get_provider_embedding_config(

@@ -6379,6 +6468,11 @@ class ProviderConfigManager:
return litellm.FireworksAIAudioTranscriptionConfig()
elif litellm.LlmProviders.DEEPGRAM == provider:
return litellm.DeepgramAudioTranscriptionConfig()
elif litellm.LlmProviders.OPENAI == provider:
if "gpt-4o" in model:
return litellm.OpenAIGPTAudioTranscriptionConfig()
else:
return litellm.OpenAIWhisperAudioTranscriptionConfig()
return None

@staticmethod

@@ -6410,10 +6504,16 @@ class ProviderConfigManager:
return litellm.FireworksAIConfig()
elif LlmProviders.OPENAI == provider:
return litellm.OpenAIGPTConfig()
elif LlmProviders.GEMINI == provider:
return litellm.GeminiModelInfo()
elif LlmProviders.LITELLM_PROXY == provider:
return litellm.LiteLLMProxyChatConfig()
elif LlmProviders.TOPAZ == provider:
return litellm.TopazModelInfo()
elif LlmProviders.ANTHROPIC == provider:
return litellm.AnthropicModelInfo()
elif LlmProviders.XAI == provider:
return litellm.XAIModelInfo()

return None

@@ -6428,6 +6528,23 @@ class ProviderConfigManager:
return litellm.TopazImageVariationConfig()
return None

@staticmethod
def get_provider_files_config(
model: str,
provider: LlmProviders,
) -> Optional[BaseFilesConfig]:
if LlmProviders.GEMINI == provider:
from litellm.llms.gemini.files.transformation import (
GoogleAIStudioFilesHandler, # experimental approach, to reduce bloat on __init__.py
)

return GoogleAIStudioFilesHandler()
elif LlmProviders.VERTEX_AI == provider:
from litellm.llms.vertex_ai.files.transformation import VertexAIFilesConfig

return VertexAIFilesConfig()
return None


def get_end_user_id_for_cost_tracking(
litellm_params: dict,

@@ -6492,7 +6609,7 @@ def is_prompt_caching_valid_prompt(
model=model,
use_default_image_token_count=True,
)
return token_count >= 1024
return token_count >= MINIMUM_PROMPT_CACHE_TOKEN_COUNT
except Exception as e:
verbose_logger.error(f"Error in is_prompt_caching_valid_prompt: {e}")
return False

@@ -6644,3 +6761,20 @@ def return_raw_request(endpoint: CallTypes, kwargs: dict) -> RawRequestTypedDict
return RawRequestTypedDict(
error=received_exception,
)


def jsonify_tools(tools: List[Any]) -> List[Dict]:
"""
Fixes https://github.com/BerriAI/litellm/issues/9321

Where user passes in a pydantic base model
"""
new_tools: List[Dict] = []
for tool in tools:
if isinstance(tool, BaseModel):
tool = tool.model_dump(exclude_none=True)
elif isinstance(tool, dict):
tool = tool.copy()
if isinstance(tool, dict):
new_tools.append(tool)
return new_tools
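
A short example of the new jsonify_tools helper with a pydantic tool definition; FunctionDef and ToolDef here are hypothetical stand-ins for whatever the caller passes in. BaseModel instances are dumped to plain dicts (None fields excluded) and plain dicts are passed through as copies.

from typing import Optional

from pydantic import BaseModel

from litellm.utils import jsonify_tools


class FunctionDef(BaseModel):  # hypothetical tool schema, for illustration
    name: str
    description: Optional[str] = None


class ToolDef(BaseModel):
    type: str = "function"
    function: FunctionDef


tools = jsonify_tools(tools=[ToolDef(function=FunctionDef(name="get_weather"))])
print(tools)  # [{'type': 'function', 'function': {'name': 'get_weather'}}]
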