Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Merge branch 'main' into sync-logging
This commit is contained in commit 0d9a3dd50c.
553 changed files with 37238 additions and 10299 deletions.

litellm/utils.py (429 changed lines shown below)
@@ -60,6 +60,7 @@ import litellm.litellm_core_utils.json_validation_rule
 from litellm.caching._internal_lru_cache import lru_cache_wrapper
 from litellm.caching.caching import DualCache
 from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
+from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import (
     map_finish_reason,

@@ -86,6 +87,7 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_s
 from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
     LiteLLMResponseObjectHandler,
     _handle_invalid_parallel_tool_calls,
+    _parse_content_for_reasoning,
     convert_to_model_response_object,
     convert_to_streaming_response,
     convert_to_streaming_response_async,

@@ -110,13 +112,17 @@ from litellm.litellm_core_utils.token_counter import (
     calculate_img_tokens,
     get_modified_max_tokens,
 )
+from litellm.llms.bedrock.common_utils import BedrockModelInfo
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.router_utils.get_retry_from_policy import (
     get_num_retries_from_retry_policy,
     reset_retry_policy,
 )
 from litellm.secret_managers.main import get_secret
-from litellm.types.llms.anthropic import ANTHROPIC_API_ONLY_HEADERS
+from litellm.types.llms.anthropic import (
+    ANTHROPIC_API_ONLY_HEADERS,
+    AnthropicThinkingParam,
+)
 from litellm.types.llms.openai import (
     AllMessageValues,
     AllPromptValues,
@@ -416,6 +422,35 @@ def _custom_logger_class_exists_in_failure_callbacks(
     )


+def get_request_guardrails(kwargs: Dict[str, Any]) -> List[str]:
+    """
+    Get the request guardrails from the kwargs
+    """
+    metadata = kwargs.get("metadata") or {}
+    requester_metadata = metadata.get("requester_metadata") or {}
+    applied_guardrails = requester_metadata.get("guardrails") or []
+    return applied_guardrails
+
+
+def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
+    """
+    - Add 'default_on' guardrails to the list
+    - Add request guardrails to the list
+    """
+
+    request_guardrails = get_request_guardrails(kwargs)
+    applied_guardrails = []
+    for callback in litellm.callbacks:
+        if callback is not None and isinstance(callback, CustomGuardrail):
+            if callback.guardrail_name is not None:
+                if callback.default_on is True:
+                    applied_guardrails.append(callback.guardrail_name)
+                elif callback.guardrail_name in request_guardrails:
+                    applied_guardrails.append(callback.guardrail_name)
+
+    return applied_guardrails
+
+
 def function_setup(  # noqa: PLR0915
     original_function: str, rules_obj, start_time, *args, **kwargs
 ):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
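The two new helpers above only read nested dictionary keys, so their behaviour can be shown standalone. A minimal sketch, not part of the diff; the guardrail names are hypothetical:

```python
# Sketch: the kwargs shape get_request_guardrails() reads.
from litellm.utils import get_request_guardrails

kwargs = {
    "metadata": {
        "requester_metadata": {"guardrails": ["prompt-injection-check", "pii-mask"]}
    }
}

print(get_request_guardrails(kwargs))  # -> ["prompt-injection-check", "pii-mask"]
print(get_request_guardrails({}))      # -> [] (missing metadata falls back to an empty list)
```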
@@ -434,6 +469,9 @@ def function_setup(  # noqa: PLR0915
     ## CUSTOM LLM SETUP ##
     custom_llm_setup()

+    ## GET APPLIED GUARDRAILS
+    applied_guardrails = get_applied_guardrails(kwargs)
+
     ## LOGGING SETUP
     function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None

@@ -583,7 +621,7 @@ def function_setup(  # noqa: PLR0915
             details_to_log.pop("prompt", None)
             add_breadcrumb(
                 category="litellm.llm_call",
-                message=f"Positional Args: {args}, Keyword Args: {details_to_log}",
+                message=f"Keyword Args: {details_to_log}",
                 level="info",
             )
         if "logger_fn" in kwargs:

@@ -675,6 +713,7 @@ def function_setup(  # noqa: PLR0915
             dynamic_async_success_callbacks=dynamic_async_success_callbacks,
             dynamic_async_failure_callbacks=dynamic_async_failure_callbacks,
             kwargs=kwargs,
+            applied_guardrails=applied_guardrails,
         )

         ## check if metadata is passed in

@@ -690,8 +729,8 @@ def function_setup(  # noqa: PLR0915
             )
         return logging_obj, kwargs
     except Exception as e:
-        verbose_logger.error(
-            f"litellm.utils.py::function_setup() - [Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
+        verbose_logger.exception(
+            "litellm.utils.py::function_setup() - [Non-Blocking] Error in function_setup"
         )
         raise e
@@ -1513,6 +1552,7 @@ def openai_token_counter(  # noqa: PLR0915
         bool
     ] = False,  # Flag passed from litellm.stream_chunk_builder, to indicate counting tokens for LLM Response. We need this because for LLM input we add +3 tokens per message - based on OpenAI's token counter
     use_default_image_token_count: Optional[bool] = False,
+    default_token_count: Optional[int] = None,
 ):
     """
     Return the number of tokens used by a list of messages.

@@ -1560,31 +1600,12 @@ def openai_token_counter(  # noqa: PLR0915
                     if key == "name":
                         num_tokens += tokens_per_name
                 elif isinstance(value, List):
-                    for c in value:
-                        if c["type"] == "text":
-                            text += c["text"]
-                            num_tokens += len(
-                                encoding.encode(c["text"], disallowed_special=())
-                            )
-                        elif c["type"] == "image_url":
-                            if isinstance(c["image_url"], dict):
-                                image_url_dict = c["image_url"]
-                                detail = image_url_dict.get("detail", "auto")
-                                url = image_url_dict.get("url")
-                                num_tokens += calculate_img_tokens(
-                                    data=url,
-                                    mode=detail,
-                                    use_default_image_token_count=use_default_image_token_count
-                                    or False,
-                                )
-                            elif isinstance(c["image_url"], str):
-                                image_url_str = c["image_url"]
-                                num_tokens += calculate_img_tokens(
-                                    data=image_url_str,
-                                    mode="auto",
-                                    use_default_image_token_count=use_default_image_token_count
-                                    or False,
-                                )
+                    text, num_tokens_from_list = _get_num_tokens_from_content_list(
+                        content_list=value,
+                        use_default_image_token_count=use_default_image_token_count,
+                        default_token_count=default_token_count,
+                    )
+                    num_tokens += num_tokens_from_list
     elif text is not None and count_response_tokens is True:
         # This is the case where we need to count tokens for a streamed response. We should NOT add +3 tokens per message in this branch
         num_tokens = len(encoding.encode(text, disallowed_special=()))
@@ -1719,15 +1740,62 @@ def _format_type(props, indent):
     return "any"


+def _get_num_tokens_from_content_list(
+    content_list: List[Dict[str, Any]],
+    use_default_image_token_count: Optional[bool] = False,
+    default_token_count: Optional[int] = None,
+) -> Tuple[str, int]:
+    """
+    Get the number of tokens from a list of content.
+
+    Returns:
+        Tuple[str, int]: A tuple containing the text and the number of tokens.
+    """
+    try:
+        num_tokens = 0
+        text = ""
+        for c in content_list:
+            if c["type"] == "text":
+                text += c["text"]
+                num_tokens += len(encoding.encode(c["text"], disallowed_special=()))
+            elif c["type"] == "image_url":
+                if isinstance(c["image_url"], dict):
+                    image_url_dict = c["image_url"]
+                    detail = image_url_dict.get("detail", "auto")
+                    url = image_url_dict.get("url")
+                    num_tokens += calculate_img_tokens(
+                        data=url,
+                        mode=detail,
+                        use_default_image_token_count=use_default_image_token_count
+                        or False,
+                    )
+                elif isinstance(c["image_url"], str):
+                    image_url_str = c["image_url"]
+                    num_tokens += calculate_img_tokens(
+                        data=image_url_str,
+                        mode="auto",
+                        use_default_image_token_count=use_default_image_token_count
+                        or False,
+                    )
+        return text, num_tokens
+    except Exception as e:
+        if default_token_count is not None:
+            return "", default_token_count
+        raise ValueError(
+            f"Error getting number of tokens from content list: {e}, default_token_count={default_token_count}"
+        )
+
+
 def token_counter(
     model="",
-    custom_tokenizer: Optional[dict] = None,
+    custom_tokenizer: Optional[Union[dict, SelectTokenizerResponse]] = None,
     text: Optional[Union[str, List[str]]] = None,
     messages: Optional[List] = None,
     count_response_tokens: Optional[bool] = False,
     tools: Optional[List[ChatCompletionToolParam]] = None,
     tool_choice: Optional[ChatCompletionNamedToolChoiceParam] = None,
     use_default_image_token_count: Optional[bool] = False,
+    default_token_count: Optional[int] = None,
 ) -> int:
     """
     Count the number of tokens in a given text using a specified model.
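For reference, a minimal sketch (not part of the diff) of the OpenAI-style content list that the new `_get_num_tokens_from_content_list` helper walks; the URLs are placeholders:

```python
content_list = [
    {"type": "text", "text": "What is shown in this picture?"},
    {
        "type": "image_url",
        "image_url": {"url": "https://example.com/cat.png", "detail": "low"},
    },
    # A bare string is also accepted for image_url; it is counted with mode="auto".
    {"type": "image_url", "image_url": "https://example.com/dog.png"},
]
# Text items are encoded with the active tokenizer; image items go through
# calculate_img_tokens(). If counting fails and default_token_count was given,
# ("", default_token_count) is returned instead of raising.
```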
@@ -1737,6 +1805,7 @@ def token_counter(
     custom_tokenizer (Optional[dict]): A custom tokenizer created with the `create_pretrained_tokenizer` or `create_tokenizer` method. Must be a dictionary with a string value for `type` and Tokenizer for `tokenizer`. Default is None.
     text (str): The raw text string to be passed to the model. Default is None.
     messages (Optional[List[Dict[str, str]]]): Alternative to passing in text. A list of dictionaries representing messages with "role" and "content" keys. Default is None.
+    default_token_count (Optional[int]): The default number of tokens to return for a message block, if an error occurs. Default is None.

     Returns:
     int: The number of tokens in the text.

@@ -1754,34 +1823,20 @@ def token_counter(
                 if isinstance(content, str):
                     text += message["content"]
                 elif isinstance(content, List):
-                    for c in content:
-                        if c["type"] == "text":
-                            text += c["text"]
-                        elif c["type"] == "image_url":
-                            if isinstance(c["image_url"], dict):
-                                image_url_dict = c["image_url"]
-                                detail = image_url_dict.get("detail", "auto")
-                                url = image_url_dict.get("url")
-                                num_tokens += calculate_img_tokens(
-                                    data=url,
-                                    mode=detail,
-                                    use_default_image_token_count=use_default_image_token_count
-                                    or False,
-                                )
-                            elif isinstance(c["image_url"], str):
-                                image_url_str = c["image_url"]
-                                num_tokens += calculate_img_tokens(
-                                    data=image_url_str,
-                                    mode="auto",
-                                    use_default_image_token_count=use_default_image_token_count
-                                    or False,
-                                )
+                    text, num_tokens = _get_num_tokens_from_content_list(
+                        content_list=content,
+                        use_default_image_token_count=use_default_image_token_count,
+                        default_token_count=default_token_count,
+                    )
                 if message.get("tool_calls"):
                     is_tool_call = True
                     for tool_call in message["tool_calls"]:
                         if "function" in tool_call:
                             function_arguments = tool_call["function"]["arguments"]
-                            text += function_arguments
+                            text = (
+                                text if isinstance(text, str) else "".join(text or [])
+                            ) + (str(function_arguments) if function_arguments else "")

         else:
             raise ValueError("text and messages cannot both be None")
     elif isinstance(text, List):
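A hedged usage sketch (not from the diff) of the refactored counter: tool-call arguments are folded into the counted text, and `default_token_count` only acts as a fallback when a content block cannot be counted. Parameter values are illustrative:

```python
import litellm

messages = [
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": '{"location": "Berlin", "unit": "celsius"}',
                },
            }
        ],
    }
]

n = litellm.token_counter(
    model="gpt-4o",
    messages=messages,
    default_token_count=10,  # hypothetical fallback value
)
print(n)
```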
@@ -1816,6 +1871,7 @@ def token_counter(
                 tool_choice=tool_choice,
                 use_default_image_token_count=use_default_image_token_count
                 or False,
+                default_token_count=default_token_count,
             )
         else:
             print_verbose(

@@ -1831,6 +1887,7 @@ def token_counter(
                 tool_choice=tool_choice,
                 use_default_image_token_count=use_default_image_token_count
                 or False,
+                default_token_count=default_token_count,
             )
     else:
         num_tokens = len(encoding.encode(text, disallowed_special=()))  # type: ignore

@@ -1911,6 +1968,19 @@ def supports_response_schema(
     )


+def supports_parallel_function_calling(
+    model: str, custom_llm_provider: Optional[str] = None
+) -> bool:
+    """
+    Check if the given model supports parallel tool calls and return a boolean value.
+    """
+    return _supports_factory(
+        model=model,
+        custom_llm_provider=custom_llm_provider,
+        key="supports_parallel_function_calling",
+    )
+
+
 def supports_function_calling(
     model: str, custom_llm_provider: Optional[str] = None
 ) -> bool:

@@ -1934,6 +2004,15 @@ def supports_function_calling(
     )


+def supports_tool_choice(model: str, custom_llm_provider: Optional[str] = None) -> bool:
+    """
+    Check if the given model supports `tool_choice` and return a boolean value.
+    """
+    return _supports_factory(
+        model=model, custom_llm_provider=custom_llm_provider, key="supports_tool_choice"
+    )
+
+
 def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str) -> bool:
     """
     Check if the given model supports function calling and return a boolean value.
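The capability helpers above all delegate to the shared `_supports_factory` lookup. A usage sketch; the model names are examples, not an authoritative list:

```python
from litellm.utils import (
    supports_function_calling,
    supports_parallel_function_calling,
    supports_tool_choice,
)

model = "gpt-4o"
if supports_function_calling(model):
    print(f"{model} accepts tools")
if supports_parallel_function_calling(model):
    print(f"{model} can return several tool calls in one turn")
if supports_tool_choice(model, custom_llm_provider="openai"):
    print(f"{model} honours tool_choice")
```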
@@ -2051,30 +2130,6 @@ def supports_embedding_image_input(
     )


-def supports_parallel_function_calling(model: str):
-    """
-    Check if the given model supports parallel function calling and return True if it does, False otherwise.
-
-    Parameters:
-    model (str): The model to check for support of parallel function calling.
-
-    Returns:
-    bool: True if the model supports parallel function calling, False otherwise.
-
-    Raises:
-    Exception: If the model is not found in the model_cost dictionary.
-    """
-    if model in litellm.model_cost:
-        model_info = litellm.model_cost[model]
-        if model_info.get("supports_parallel_function_calling", False) is True:
-            return True
-        return False
-    else:
-        raise Exception(
-            f"Model not supports parallel function calling. You passed model={model}."
-        )
-
-
 ####### HELPER FUNCTIONS ################
 def _update_dictionary(existing_dict: Dict, new_dict: dict) -> dict:
     for k, v in new_dict.items():

@@ -2682,8 +2737,10 @@ def get_optional_params(  # noqa: PLR0915
     api_version=None,
     parallel_tool_calls=None,
     drop_params=None,
+    reasoning_effort=None,
     additional_drop_params=None,
     messages: Optional[List[AllMessageValues]] = None,
+    thinking: Optional[AnthropicThinkingParam] = None,
     **kwargs,
 ):
     # retrieve all parameters passed to the function

@@ -2767,9 +2824,12 @@ def get_optional_params(  # noqa: PLR0915
         "drop_params": None,
         "additional_drop_params": None,
         "messages": None,
+        "reasoning_effort": None,
+        "thinking": None,
     }

     # filter out those parameters that were passed with non-default values
     non_default_params = {
         k: v
         for k, v in passed_params.items()
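The two new parameters threaded through `get_optional_params` above surface on the public completion call. A hedged sketch; the values are illustrative, and the `thinking` dict is assumed to follow the `AnthropicThinkingParam` shape ({"type": "enabled", "budget_tokens": ...}):

```python
import litellm

# OpenAI-style reasoning models: pass reasoning_effort
resp = litellm.completion(
    model="o3-mini",
    messages=[{"role": "user", "content": "Plan a 3-day trip to Kyoto"}],
    reasoning_effort="low",
)

# Anthropic-style models: pass a thinking block
resp = litellm.completion(
    model="anthropic/claude-3-7-sonnet-20250219",
    messages=[{"role": "user", "content": "Plan a 3-day trip to Kyoto"}],
    thinking={"type": "enabled", "budget_tokens": 1024},
)
```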
@@ -3112,51 +3172,56 @@ def get_optional_params(  # noqa: PLR0915
                else False
            ),
        )
    elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_llama3_models:
        optional_params = litellm.VertexAILlama3Config().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=(
                drop_params
                if drop_params is not None and isinstance(drop_params, bool)
                else False
            ),
        )
    elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_mistral_models:
        if "codestral" in model:
            optional_params = litellm.CodestralTextCompletionConfig().map_openai_params(
                model=model,
    elif custom_llm_provider == "vertex_ai":

        if model in litellm.vertex_mistral_models:
            if "codestral" in model:
                optional_params = (
                    litellm.CodestralTextCompletionConfig().map_openai_params(
                        model=model,
                        non_default_params=non_default_params,
                        optional_params=optional_params,
                        drop_params=(
                            drop_params
                            if drop_params is not None and isinstance(drop_params, bool)
                            else False
                        ),
                    )
                )
            else:
                optional_params = litellm.MistralConfig().map_openai_params(
                    model=model,
                    non_default_params=non_default_params,
                    optional_params=optional_params,
                    drop_params=(
                        drop_params
                        if drop_params is not None and isinstance(drop_params, bool)
                        else False
                    ),
                )
        elif model in litellm.vertex_ai_ai21_models:
            optional_params = litellm.VertexAIAi21Config().map_openai_params(
                non_default_params=non_default_params,
                optional_params=optional_params,
                model=model,
                drop_params=(
                    drop_params
                    if drop_params is not None and isinstance(drop_params, bool)
                    else False
                ),
            )
        else:
            optional_params = litellm.MistralConfig().map_openai_params(
                model=model,
        else:  # use generic openai-like param mapping
            optional_params = litellm.VertexAILlama3Config().map_openai_params(
                non_default_params=non_default_params,
                optional_params=optional_params,
                model=model,
                drop_params=(
                    drop_params
                    if drop_params is not None and isinstance(drop_params, bool)
                    else False
                ),
            )
    elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_ai_ai21_models:
        optional_params = litellm.VertexAIAi21Config().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=(
                drop_params
                if drop_params is not None and isinstance(drop_params, bool)
                else False
            ),
        )

    elif custom_llm_provider == "sagemaker":
        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
        optional_params = litellm.SagemakerConfig().map_openai_params(
@@ -3170,8 +3235,8 @@ def get_optional_params(  # noqa: PLR0915
             ),
         )
     elif custom_llm_provider == "bedrock":
-        base_model = litellm.AmazonConverseConfig()._get_base_model(model)
-        if base_model in litellm.bedrock_converse_models:
+        bedrock_route = BedrockModelInfo.get_bedrock_route(model)
+        if bedrock_route == "converse" or bedrock_route == "converse_like":
             optional_params = litellm.AmazonConverseConfig().map_openai_params(
                 model=model,
                 non_default_params=non_default_params,

@@ -3184,15 +3249,20 @@ def get_optional_params(  # noqa: PLR0915
                 messages=messages,
             )

-        elif "anthropic" in model:
-            if "aws_bedrock_client" in passed_params:  # deprecated boto3.invoke route.
-                if model.startswith("anthropic.claude-3"):
-                    optional_params = (
-                        litellm.AmazonAnthropicClaude3Config().map_openai_params(
-                            non_default_params=non_default_params,
-                            optional_params=optional_params,
-                        )
+        elif "anthropic" in model and bedrock_route == "invoke":
+            if model.startswith("anthropic.claude-3"):
+                optional_params = (
+                    litellm.AmazonAnthropicClaude3Config().map_openai_params(
+                        non_default_params=non_default_params,
+                        optional_params=optional_params,
+                        model=model,
+                        drop_params=(
+                            drop_params
+                            if drop_params is not None and isinstance(drop_params, bool)
+                            else False
+                        ),
+                    )
                 )
             else:
                 optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                     non_default_params=non_default_params,
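A hedged sketch (not from the diff) of how the route returned by `BedrockModelInfo.get_bedrock_route()` drives the branching above. The model id is an example; the route strings are the ones visible in this hunk, and the exact detection logic is assumed:

```python
from litellm.llms.bedrock.common_utils import BedrockModelInfo

route = BedrockModelInfo.get_bedrock_route("anthropic.claude-3-5-sonnet-20240620-v1:0")

if route in ("converse", "converse_like"):
    config = "AmazonConverseConfig"   # converse API param mapping
elif route == "invoke":
    config = "invoke-route config"    # e.g. AmazonAnthropicClaude3Config for claude-3
else:
    config = "other"
print(route, config)
```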
@@ -3953,8 +4023,16 @@ def _strip_stable_vertex_version(model_name) -> str:
     return re.sub(r"-\d+$", "", model_name)


-def _strip_bedrock_region(model_name) -> str:
-    return litellm.AmazonConverseConfig()._get_base_model(model_name)
+def _get_base_bedrock_model(model_name) -> str:
+    """
+    Get the base model from the given model name.
+
+    Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+    AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
+    """
+    from litellm.llms.bedrock.common_utils import BedrockModelInfo
+
+    return BedrockModelInfo.get_base_model(model_name)


 def _strip_openai_finetune_model_name(model_name: str) -> str:
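The mapping documented in the docstring above, restated as a small sketch (results hedged as "expected per the docstring"):

```python
from litellm.utils import _get_base_bedrock_model

print(_get_base_bedrock_model("us.meta.llama3-2-11b-instruct-v1:0"))
# expected per the docstring: "meta.llama3-2-11b-instruct-v1"
print(_get_base_bedrock_model("meta.llama3-2-11b-instruct-v1:0"))
# expected per the docstring: "meta.llama3-2-11b-instruct-v1"
```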
@@ -3975,8 +4053,8 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:

 def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
     if custom_llm_provider and custom_llm_provider == "bedrock":
-        strip_bedrock_region = _strip_bedrock_region(model_name=model)
-        return strip_bedrock_region
+        stripped_bedrock_model = _get_base_bedrock_model(model_name=model)
+        return stripped_bedrock_model
     elif custom_llm_provider and (
         custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
     ):

@@ -4012,7 +4090,7 @@ def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str])
         "litellm_provider"
     ].startswith("fireworks_ai"):
         return True
-    elif custom_llm_provider == "bedrock" and model_info[
+    elif custom_llm_provider.startswith("bedrock") and model_info[
         "litellm_provider"
     ].startswith("bedrock"):
         return True

@@ -4101,7 +4179,6 @@ def _get_max_position_embeddings(model_name: str) -> Optional[int]:
     return None

-
 @lru_cache_wrapper(maxsize=16)
 def _cached_get_model_info_helper(
     model: str, custom_llm_provider: Optional[str]
 ) -> ModelInfoBase:

@@ -4183,6 +4260,7 @@ def _get_model_info_helper(  # noqa: PLR0915
             supports_system_messages=None,
             supports_response_schema=None,
             supports_function_calling=None,
+            supports_tool_choice=None,
             supports_assistant_prefill=None,
             supports_prompt_caching=None,
             supports_pdf_input=None,

@@ -4327,6 +4405,7 @@ def _get_model_info_helper(  # noqa: PLR0915
             supports_function_calling=_model_info.get(
                 "supports_function_calling", False
             ),
+            supports_tool_choice=_model_info.get("supports_tool_choice", False),
             supports_assistant_prefill=_model_info.get(
                 "supports_assistant_prefill", False
             ),

@@ -4405,6 +4484,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
         supports_response_schema: Optional[bool]
         supports_vision: Optional[bool]
         supports_function_calling: Optional[bool]
+        supports_tool_choice: Optional[bool]
         supports_prompt_caching: Optional[bool]
         supports_audio_input: Optional[bool]
         supports_audio_output: Optional[bool]

@@ -5119,9 +5199,10 @@
 # custom prompt helper function
 def register_prompt_template(
     model: str,
-    roles: dict,
+    roles: dict = {},
     initial_prompt_value: str = "",
     final_prompt_value: str = "",
+    tokenizer_config: dict = {},
 ):
     """
     Register a prompt template to follow your custom format for a given model

@@ -5158,12 +5239,27 @@ def register_prompt_template(
     )
     ```
     """
-    model = get_llm_provider(model=model)[0]
-    litellm.custom_prompt_dict[model] = {
-        "roles": roles,
-        "initial_prompt_value": initial_prompt_value,
-        "final_prompt_value": final_prompt_value,
-    }
+    complete_model = model
+    potential_models = [complete_model]
+    try:
+        model = get_llm_provider(model=model)[0]
+        potential_models.append(model)
+    except Exception:
+        pass
+    if tokenizer_config:
+        for m in potential_models:
+            litellm.known_tokenizer_config[m] = {
+                "tokenizer": tokenizer_config,
+                "status": "success",
+            }
+    else:
+        for m in potential_models:
+            litellm.custom_prompt_dict[m] = {
+                "roles": roles,
+                "initial_prompt_value": initial_prompt_value,
+                "final_prompt_value": final_prompt_value,
+            }

     return litellm.custom_prompt_dict
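With the change above, `register_prompt_template` can also register a Hugging Face style chat template via `tokenizer_config` instead of a roles dict. A hedged sketch; the model name and template string are illustrative, and the provider-stripped name is only registered if the prefix can be resolved:

```python
import litellm
from litellm.utils import register_prompt_template

register_prompt_template(
    model="huggingface/my-org/my-model",
    tokenizer_config={
        "chat_template": "{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}\n{% endfor %}",
        "bos_token": "<s>",
        "eos_token": "</s>",
    },
)

# Both the full name and the provider-stripped name end up registered.
print("huggingface/my-org/my-model" in litellm.known_tokenizer_config)
print("my-org/my-model" in litellm.known_tokenizer_config)
```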
@@ -5827,6 +5923,18 @@ def convert_to_dict(message: Union[BaseModel, dict]) -> dict:
         )


+def validate_and_fix_openai_messages(messages: List):
+    """
+    Ensures all messages are valid OpenAI chat completion messages.
+
+    Handles missing role for assistant messages.
+    """
+    for message in messages:
+        if not message.get("role"):
+            message["role"] = "assistant"
+    return validate_chat_completion_messages(messages=messages)
+
+
 def validate_chat_completion_messages(messages: List[AllMessageValues]):
     """
     Ensures all messages are valid OpenAI chat completion messages.
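Behaviour sketch for the new helper above: a message without a role is assumed to be an assistant message before normal validation runs.

```python
from litellm.utils import validate_and_fix_openai_messages

messages = [
    {"role": "user", "content": "hi"},
    {"content": "Hello! How can I help?"},  # role missing
]
validate_and_fix_openai_messages(messages)
print(messages[1]["role"])  # -> "assistant"
```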
@@ -5866,6 +5974,10 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
                         if item.get("type") not in ValidUserMessageContentTypes:
                             raise Exception("invalid content type")
         except Exception as e:
+            if isinstance(e, KeyError):
+                raise Exception(
+                    f"Invalid message={m} at index {idx}. Please ensure all messages are valid OpenAI chat completion messages."
+                )
             if "invalid content type" in str(e):
                 raise Exception(
                     f"Invalid user message={m} at index {idx}. Please ensure all user messages are valid OpenAI chat completion messages."

@@ -6040,25 +6152,29 @@ class ProviderConfigManager:
         elif litellm.LlmProviders.PETALS == provider:
             return litellm.PetalsConfig()
         elif litellm.LlmProviders.BEDROCK == provider:
-            base_model = litellm.AmazonConverseConfig()._get_base_model(model)
-            bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
-            if (
-                base_model in litellm.bedrock_converse_models
-                or "converse_like" in model
-            ):
+            bedrock_route = BedrockModelInfo.get_bedrock_route(model)
+            bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(
+                model
+            )
+
+            if bedrock_route == "converse" or bedrock_route == "converse_like":
                 return litellm.AmazonConverseConfig()
-            elif bedrock_provider == "amazon":  # amazon titan llms
+            elif bedrock_invoke_provider == "amazon":  # amazon titan llms
                 return litellm.AmazonTitanConfig()
             elif (
-                bedrock_provider == "meta" or bedrock_provider == "llama"
+                bedrock_invoke_provider == "meta" or bedrock_invoke_provider == "llama"
             ):  # amazon / meta llms
                 return litellm.AmazonLlamaConfig()
-            elif bedrock_provider == "ai21":  # ai21 llms
+            elif bedrock_invoke_provider == "ai21":  # ai21 llms
                 return litellm.AmazonAI21Config()
-            elif bedrock_provider == "cohere":  # cohere models on bedrock
+            elif bedrock_invoke_provider == "cohere":  # cohere models on bedrock
                 return litellm.AmazonCohereConfig()
-            elif bedrock_provider == "mistral":  # mistral models on bedrock
+            elif bedrock_invoke_provider == "mistral":  # mistral models on bedrock
                 return litellm.AmazonMistralConfig()
+            elif bedrock_invoke_provider == "deepseek_r1":  # deepseek models on bedrock
+                return litellm.AmazonDeepSeekR1Config()
             else:
                 return litellm.AmazonInvokeConfig()
         return litellm.OpenAIGPTConfig()

     @staticmethod

@@ -6078,13 +6194,20 @@ class ProviderConfigManager:
     def get_provider_rerank_config(
         model: str,
         provider: LlmProviders,
+        api_base: Optional[str],
+        present_version_params: List[str],
     ) -> BaseRerankConfig:
         if litellm.LlmProviders.COHERE == provider:
-            return litellm.CohereRerankConfig()
+            if should_use_cohere_v1_client(api_base, present_version_params):
+                return litellm.CohereRerankConfig()
+            else:
+                return litellm.CohereRerankV2Config()
         elif litellm.LlmProviders.AZURE_AI == provider:
             return litellm.AzureAIRerankConfig()
         elif litellm.LlmProviders.INFINITY == provider:
             return litellm.InfinityRerankConfig()
         elif litellm.LlmProviders.JINA_AI == provider:
             return litellm.JinaAIRerankConfig()
         return litellm.CohereRerankConfig()

     @staticmethod

@@ -6163,6 +6286,19 @@ def get_end_user_id_for_cost_tracking(
     return end_user_id


+def should_use_cohere_v1_client(
+    api_base: Optional[str], present_version_params: List[str]
+):
+    if not api_base:
+        return False
+    uses_v1_params = ("max_chunks_per_doc" in present_version_params) and (
+        "max_tokens_per_doc" not in present_version_params
+    )
+    return api_base.endswith("/v1/rerank") or (
+        uses_v1_params and not api_base.endswith("/v2/rerank")
+    )
+
+
 def is_prompt_caching_valid_prompt(
     model: str,
     messages: Optional[List[AllMessageValues]],
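Behaviour sketch for `should_use_cohere_v1_client` above; the api_base values are illustrative:

```python
from litellm.utils import should_use_cohere_v1_client

# Explicit v1 endpoint -> v1 client
print(should_use_cohere_v1_client("https://api.cohere.ai/v1/rerank", []))                      # True
# v1-only parameter present and no v2 endpoint -> v1 client
print(should_use_cohere_v1_client("https://api.cohere.ai", ["max_chunks_per_doc"]))            # True
# Explicit v2 endpoint -> v2 client
print(should_use_cohere_v1_client("https://api.cohere.ai/v2/rerank", ["max_chunks_per_doc"]))  # False
# No api_base -> default to the v2 client
print(should_use_cohere_v1_client(None, ["max_chunks_per_doc"]))                               # False
```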
@@ -6273,7 +6409,9 @@ def get_non_default_completion_params(kwargs: dict) -> dict:

 def add_openai_metadata(metadata: dict) -> dict:
     """
-    Add metadata to openai optional parameters, excluding hidden params
+    Add metadata to openai optional parameters, excluding hidden params.
+
+    OpenAI 'metadata' only supports string values.

     Args:
         params (dict): Dictionary of API parameters

@@ -6285,5 +6423,10 @@ def add_openai_metadata(metadata: dict) -> dict:
     if metadata is None:
         return None
     # Only include non-hidden parameters
-    visible_metadata = {k: v for k, v in metadata.items() if k != "hidden_params"}
+    visible_metadata = {
+        k: v
+        for k, v in metadata.items()
+        if k != "hidden_params" and isinstance(v, (str))
+    }

     return visible_metadata.copy()
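Behaviour sketch for the stricter filter above: non-string values are now dropped along with hidden_params, since OpenAI metadata only accepts string values. The keys are illustrative:

```python
from litellm.utils import add_openai_metadata

metadata = {
    "trace_id": "abc-123",
    "attempt": 2,                       # int -> dropped
    "hidden_params": {"api_key": "…"},  # always dropped
}
print(add_openai_metadata(metadata))    # -> {'trace_id': 'abc-123'}
```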