Merge branch 'main' into sync-logging

Yuki Watanabe 2025-02-28 14:44:39 +09:00 committed by GitHub
commit 0d9a3dd50c
553 changed files with 37238 additions and 10299 deletions


@@ -60,6 +60,7 @@ import litellm.litellm_core_utils.json_validation_rule
from litellm.caching._internal_lru_cache import lru_cache_wrapper
from litellm.caching.caching import DualCache
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
map_finish_reason,
@@ -86,6 +87,7 @@ from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_s
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
LiteLLMResponseObjectHandler,
_handle_invalid_parallel_tool_calls,
_parse_content_for_reasoning,
convert_to_model_response_object,
convert_to_streaming_response,
convert_to_streaming_response_async,
@@ -110,13 +112,17 @@ from litellm.litellm_core_utils.token_counter import (
calculate_img_tokens,
get_modified_max_tokens,
)
from litellm.llms.bedrock.common_utils import BedrockModelInfo
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.router_utils.get_retry_from_policy import (
get_num_retries_from_retry_policy,
reset_retry_policy,
)
from litellm.secret_managers.main import get_secret
from litellm.types.llms.anthropic import ANTHROPIC_API_ONLY_HEADERS
from litellm.types.llms.anthropic import (
ANTHROPIC_API_ONLY_HEADERS,
AnthropicThinkingParam,
)
from litellm.types.llms.openai import (
AllMessageValues,
AllPromptValues,
@@ -416,6 +422,35 @@ def _custom_logger_class_exists_in_failure_callbacks(
)
def get_request_guardrails(kwargs: Dict[str, Any]) -> List[str]:
"""
Get the request guardrails from the kwargs
"""
metadata = kwargs.get("metadata") or {}
requester_metadata = metadata.get("requester_metadata") or {}
applied_guardrails = requester_metadata.get("guardrails") or []
return applied_guardrails
def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
"""
- Add 'default_on' guardrails to the list
- Add request guardrails to the list
"""
request_guardrails = get_request_guardrails(kwargs)
applied_guardrails = []
for callback in litellm.callbacks:
if callback is not None and isinstance(callback, CustomGuardrail):
if callback.guardrail_name is not None:
if callback.default_on is True:
applied_guardrails.append(callback.guardrail_name)
elif callback.guardrail_name in request_guardrails:
applied_guardrails.append(callback.guardrail_name)
return applied_guardrails
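
For illustration, a self-contained sketch of the selection rule above, using a stand-in object in place of the real CustomGuardrail class (guardrail names are hypothetical):

from dataclasses import dataclass
from typing import Optional

@dataclass
class _StubGuardrail:
    # stand-in for CustomGuardrail; only the two attributes the rule reads
    guardrail_name: Optional[str]
    default_on: bool = False

callbacks = [
    _StubGuardrail("prompt-injection", default_on=True),  # default_on -> always applied
    _StubGuardrail("pii-mask"),                            # applied only when requested
    _StubGuardrail(None),                                  # skipped: no guardrail_name
]
request_guardrails = ["pii-mask"]  # i.e. metadata["requester_metadata"]["guardrails"]

applied = [
    cb.guardrail_name
    for cb in callbacks
    if cb.guardrail_name is not None
    and (cb.default_on or cb.guardrail_name in request_guardrails)
]
assert applied == ["prompt-injection", "pii-mask"]
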
def function_setup( # noqa: PLR0915
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -434,6 +469,9 @@ def function_setup( # noqa: PLR0915
## CUSTOM LLM SETUP ##
custom_llm_setup()
## GET APPLIED GUARDRAILS
applied_guardrails = get_applied_guardrails(kwargs)
## LOGGING SETUP
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
@@ -583,7 +621,7 @@ def function_setup( # noqa: PLR0915
details_to_log.pop("prompt", None)
add_breadcrumb(
category="litellm.llm_call",
message=f"Positional Args: {args}, Keyword Args: {details_to_log}",
message=f"Keyword Args: {details_to_log}",
level="info",
)
if "logger_fn" in kwargs:
@@ -675,6 +713,7 @@ def function_setup( # noqa: PLR0915
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
dynamic_async_failure_callbacks=dynamic_async_failure_callbacks,
kwargs=kwargs,
applied_guardrails=applied_guardrails,
)
## check if metadata is passed in
@@ -690,8 +729,8 @@ def function_setup( # noqa: PLR0915
)
return logging_obj, kwargs
except Exception as e:
verbose_logger.error(
f"litellm.utils.py::function_setup() - [Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
verbose_logger.exception(
"litellm.utils.py::function_setup() - [Non-Blocking] Error in function_setup"
)
raise e
@@ -1513,6 +1552,7 @@ def openai_token_counter( # noqa: PLR0915
bool
] = False, # Flag passed from litellm.stream_chunk_builder, to indicate counting tokens for LLM Response. We need this because for LLM input we add +3 tokens per message - based on OpenAI's token counter
use_default_image_token_count: Optional[bool] = False,
default_token_count: Optional[int] = None,
):
"""
Return the number of tokens used by a list of messages.
@@ -1560,31 +1600,12 @@ def openai_token_counter( # noqa: PLR0915
if key == "name":
num_tokens += tokens_per_name
elif isinstance(value, List):
for c in value:
if c["type"] == "text":
text += c["text"]
num_tokens += len(
encoding.encode(c["text"], disallowed_special=())
)
elif c["type"] == "image_url":
if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
url = image_url_dict.get("url")
num_tokens += calculate_img_tokens(
data=url,
mode=detail,
use_default_image_token_count=use_default_image_token_count
or False,
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculate_img_tokens(
data=image_url_str,
mode="auto",
use_default_image_token_count=use_default_image_token_count
or False,
)
text, num_tokens_from_list = _get_num_tokens_from_content_list(
content_list=value,
use_default_image_token_count=use_default_image_token_count,
default_token_count=default_token_count,
)
num_tokens += num_tokens_from_list
elif text is not None and count_response_tokens is True:
# This is the case where we need to count tokens for a streamed response. We should NOT add +3 tokens per message in this branch
num_tokens = len(encoding.encode(text, disallowed_special=()))
@@ -1719,15 +1740,62 @@ def _format_type(props, indent):
return "any"
def _get_num_tokens_from_content_list(
content_list: List[Dict[str, Any]],
use_default_image_token_count: Optional[bool] = False,
default_token_count: Optional[int] = None,
) -> Tuple[str, int]:
"""
Get the number of tokens from a list of content.
Returns:
Tuple[str, int]: A tuple containing the text and the number of tokens.
"""
try:
num_tokens = 0
text = ""
for c in content_list:
if c["type"] == "text":
text += c["text"]
num_tokens += len(encoding.encode(c["text"], disallowed_special=()))
elif c["type"] == "image_url":
if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
url = image_url_dict.get("url")
num_tokens += calculate_img_tokens(
data=url,
mode=detail,
use_default_image_token_count=use_default_image_token_count
or False,
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculate_img_tokens(
data=image_url_str,
mode="auto",
use_default_image_token_count=use_default_image_token_count
or False,
)
return text, num_tokens
except Exception as e:
if default_token_count is not None:
return "", default_token_count
raise ValueError(
f"Error getting number of tokens from content list: {e}, default_token_count={default_token_count}"
)
def token_counter(
model="",
custom_tokenizer: Optional[dict] = None,
custom_tokenizer: Optional[Union[dict, SelectTokenizerResponse]] = None,
text: Optional[Union[str, List[str]]] = None,
messages: Optional[List] = None,
count_response_tokens: Optional[bool] = False,
tools: Optional[List[ChatCompletionToolParam]] = None,
tool_choice: Optional[ChatCompletionNamedToolChoiceParam] = None,
use_default_image_token_count: Optional[bool] = False,
default_token_count: Optional[int] = None,
) -> int:
"""
Count the number of tokens in a given text using a specified model.
@@ -1737,6 +1805,7 @@ def token_counter(
custom_tokenizer (Optional[dict]): A custom tokenizer created with the `create_pretrained_tokenizer` or `create_tokenizer` method. Must be a dictionary with a string value for `type` and Tokenizer for `tokenizer`. Default is None.
text (str): The raw text string to be passed to the model. Default is None.
messages (Optional[List[Dict[str, str]]]): Alternative to passing in text. A list of dictionaries representing messages with "role" and "content" keys. Default is None.
default_token_count (Optional[int]): The default number of tokens to return for a message block, if an error occurs. Default is None.
Returns:
int: The number of tokens in the text.
@@ -1754,34 +1823,20 @@ def token_counter(
if isinstance(content, str):
text += message["content"]
elif isinstance(content, List):
for c in content:
if c["type"] == "text":
text += c["text"]
elif c["type"] == "image_url":
if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
url = image_url_dict.get("url")
num_tokens += calculate_img_tokens(
data=url,
mode=detail,
use_default_image_token_count=use_default_image_token_count
or False,
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculate_img_tokens(
data=image_url_str,
mode="auto",
use_default_image_token_count=use_default_image_token_count
or False,
)
text, num_tokens = _get_num_tokens_from_content_list(
content_list=content,
use_default_image_token_count=use_default_image_token_count,
default_token_count=default_token_count,
)
if message.get("tool_calls"):
is_tool_call = True
for tool_call in message["tool_calls"]:
if "function" in tool_call:
function_arguments = tool_call["function"]["arguments"]
text += function_arguments
text = (
text if isinstance(text, str) else "".join(text or [])
) + (str(function_arguments) if function_arguments else "")
else:
raise ValueError("text and messages cannot both be None")
elif isinstance(text, List):
@@ -1816,6 +1871,7 @@ def token_counter(
tool_choice=tool_choice,
use_default_image_token_count=use_default_image_token_count
or False,
default_token_count=default_token_count,
)
else:
print_verbose(
@@ -1831,6 +1887,7 @@ def token_counter(
tool_choice=tool_choice,
use_default_image_token_count=use_default_image_token_count
or False,
default_token_count=default_token_count,
)
else:
num_tokens = len(encoding.encode(text, disallowed_special=())) # type: ignore
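
A sketch of how the new default_token_count is meant to be used from the public token_counter() entry point; the model name and image URL are placeholders, not from this diff:

import litellm

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png", "detail": "low"}},
        ],
    }
]

# If _get_num_tokens_from_content_list() raises while counting this content
# list, it now returns default_token_count instead of propagating the error.
tokens = litellm.token_counter(
    model="gpt-4o",
    messages=messages,
    default_token_count=100,
)
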
@@ -1911,6 +1968,19 @@ def supports_response_schema(
)
def supports_parallel_function_calling(
model: str, custom_llm_provider: Optional[str] = None
) -> bool:
"""
Check if the given model supports parallel tool calls and return a boolean value.
"""
return _supports_factory(
model=model,
custom_llm_provider=custom_llm_provider,
key="supports_parallel_function_calling",
)
def supports_function_calling(
model: str, custom_llm_provider: Optional[str] = None
) -> bool:
@@ -1934,6 +2004,15 @@ def supports_function_calling(
)
def supports_tool_choice(model: str, custom_llm_provider: Optional[str] = None) -> bool:
"""
Check if the given model supports `tool_choice` and return a boolean value.
"""
return _supports_factory(
model=model, custom_llm_provider=custom_llm_provider, key="supports_tool_choice"
)
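
Both new capability helpers (supports_parallel_function_calling above and supports_tool_choice here) route through _supports_factory, so they read the "supports_*" flags from the model cost map; a usage sketch with assumed model names:

from litellm.utils import supports_parallel_function_calling, supports_tool_choice

# Return values depend on the local model cost map entries;
# the model names below are assumptions, not from this diff.
supports_tool_choice(model="gpt-4o")
supports_parallel_function_calling(model="gpt-4o")
supports_parallel_function_calling(model="gpt-4o", custom_llm_provider="openai")
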
def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str) -> bool:
"""
Check if the given model supports function calling and return a boolean value.
@@ -2051,30 +2130,6 @@ def supports_embedding_image_input(
)
def supports_parallel_function_calling(model: str):
"""
Check if the given model supports parallel function calling and return True if it does, False otherwise.
Parameters:
model (str): The model to check for support of parallel function calling.
Returns:
bool: True if the model supports parallel function calling, False otherwise.
Raises:
Exception: If the model is not found in the model_cost dictionary.
"""
if model in litellm.model_cost:
model_info = litellm.model_cost[model]
if model_info.get("supports_parallel_function_calling", False) is True:
return True
return False
else:
raise Exception(
f"Model not supports parallel function calling. You passed model={model}."
)
####### HELPER FUNCTIONS ################
def _update_dictionary(existing_dict: Dict, new_dict: dict) -> dict:
for k, v in new_dict.items():
@@ -2682,8 +2737,10 @@ def get_optional_params( # noqa: PLR0915
api_version=None,
parallel_tool_calls=None,
drop_params=None,
reasoning_effort=None,
additional_drop_params=None,
messages: Optional[List[AllMessageValues]] = None,
thinking: Optional[AnthropicThinkingParam] = None,
**kwargs,
):
# retrieve all parameters passed to the function
@@ -2767,9 +2824,12 @@ def get_optional_params( # noqa: PLR0915
"drop_params": None,
"additional_drop_params": None,
"messages": None,
"reasoning_effort": None,
"thinking": None,
}
# filter out those parameters that were passed with non-default values
non_default_params = {
k: v
for k, v in passed_params.items()
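
Illustration of how the two new pass-through parameters (reasoning_effort and thinking) would be supplied by callers; the model names, the budget value, and the assumption that litellm.completion forwards them here are not from this diff:

import litellm

# OpenAI-style reasoning effort (string enum)
litellm.completion(
    model="o3-mini",  # assumed example model
    messages=[{"role": "user", "content": "Outline a migration plan."}],
    reasoning_effort="high",
)

# Anthropic-style extended thinking (AnthropicThinkingParam-shaped dict)
litellm.completion(
    model="anthropic/claude-3-7-sonnet-20250219",  # assumed example model
    messages=[{"role": "user", "content": "Outline a migration plan."}],
    thinking={"type": "enabled", "budget_tokens": 1024},
)
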
@@ -3112,51 +3172,56 @@ def get_optional_params( # noqa: PLR0915
else False
),
)
elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_llama3_models:
optional_params = litellm.VertexAILlama3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_mistral_models:
if "codestral" in model:
optional_params = litellm.CodestralTextCompletionConfig().map_openai_params(
model=model,
elif custom_llm_provider == "vertex_ai":
if model in litellm.vertex_mistral_models:
if "codestral" in model:
optional_params = (
litellm.CodestralTextCompletionConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
)
else:
optional_params = litellm.MistralConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif model in litellm.vertex_ai_ai21_models:
optional_params = litellm.VertexAIAi21Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
else:
optional_params = litellm.MistralConfig().map_openai_params(
model=model,
else: # use generic openai-like param mapping
optional_params = litellm.VertexAILlama3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_ai_ai21_models:
optional_params = litellm.VertexAIAi21Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "sagemaker":
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
optional_params = litellm.SagemakerConfig().map_openai_params(
@@ -3170,8 +3235,8 @@ def get_optional_params( # noqa: PLR0915
),
)
elif custom_llm_provider == "bedrock":
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
if base_model in litellm.bedrock_converse_models:
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
if bedrock_route == "converse" or bedrock_route == "converse_like":
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
@@ -3184,15 +3249,20 @@ def get_optional_params( # noqa: PLR0915
messages=messages,
)
elif "anthropic" in model:
if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif "anthropic" in model and bedrock_route == "invoke":
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
@@ -3953,8 +4023,16 @@ def _strip_stable_vertex_version(model_name) -> str:
return re.sub(r"-\d+$", "", model_name)
def _strip_bedrock_region(model_name) -> str:
return litellm.AmazonConverseConfig()._get_base_model(model_name)
def _get_base_bedrock_model(model_name) -> str:
"""
Get the base model from the given model name.
Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
from litellm.llms.bedrock.common_utils import BedrockModelInfo
return BedrockModelInfo.get_base_model(model_name)
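
Usage matching the docstring examples above (outputs are the docstring's own, produced via BedrockModelInfo.get_base_model):

_get_base_bedrock_model("us.meta.llama3-2-11b-instruct-v1:0")   # -> "meta.llama3-2-11b-instruct-v1"
_get_base_bedrock_model("meta.llama3-2-11b-instruct-v1:0")      # -> "meta.llama3-2-11b-instruct-v1"
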
def _strip_openai_finetune_model_name(model_name: str) -> str:
@@ -3975,8 +4053,8 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:
def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
if custom_llm_provider and custom_llm_provider == "bedrock":
strip_bedrock_region = _strip_bedrock_region(model_name=model)
return strip_bedrock_region
stripped_bedrock_model = _get_base_bedrock_model(model_name=model)
return stripped_bedrock_model
elif custom_llm_provider and (
custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
):
@@ -4012,7 +4090,7 @@ def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str])
"litellm_provider"
].startswith("fireworks_ai"):
return True
elif custom_llm_provider == "bedrock" and model_info[
elif custom_llm_provider.startswith("bedrock") and model_info[
"litellm_provider"
].startswith("bedrock"):
return True
@@ -4101,7 +4179,6 @@ def _get_max_position_embeddings(model_name: str) -> Optional[int]:
return None
@lru_cache_wrapper(maxsize=16)
def _cached_get_model_info_helper(
model: str, custom_llm_provider: Optional[str]
) -> ModelInfoBase:
@@ -4183,6 +4260,7 @@ def _get_model_info_helper( # noqa: PLR0915
supports_system_messages=None,
supports_response_schema=None,
supports_function_calling=None,
supports_tool_choice=None,
supports_assistant_prefill=None,
supports_prompt_caching=None,
supports_pdf_input=None,
@@ -4327,6 +4405,7 @@ def _get_model_info_helper( # noqa: PLR0915
supports_function_calling=_model_info.get(
"supports_function_calling", False
),
supports_tool_choice=_model_info.get("supports_tool_choice", False),
supports_assistant_prefill=_model_info.get(
"supports_assistant_prefill", False
),
@@ -4405,6 +4484,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
supports_response_schema: Optional[bool]
supports_vision: Optional[bool]
supports_function_calling: Optional[bool]
supports_tool_choice: Optional[bool]
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_audio_output: Optional[bool]
@@ -5119,9 +5199,10 @@ def _calculate_retry_after(
# custom prompt helper function
def register_prompt_template(
model: str,
roles: dict,
roles: dict = {},
initial_prompt_value: str = "",
final_prompt_value: str = "",
tokenizer_config: dict = {},
):
"""
Register a prompt template to follow your custom format for a given model
@@ -5158,12 +5239,27 @@ def register_prompt_template(
)
```
"""
model = get_llm_provider(model=model)[0]
litellm.custom_prompt_dict[model] = {
"roles": roles,
"initial_prompt_value": initial_prompt_value,
"final_prompt_value": final_prompt_value,
}
complete_model = model
potential_models = [complete_model]
try:
model = get_llm_provider(model=model)[0]
potential_models.append(model)
except Exception:
pass
if tokenizer_config:
for m in potential_models:
litellm.known_tokenizer_config[m] = {
"tokenizer": tokenizer_config,
"status": "success",
}
else:
for m in potential_models:
litellm.custom_prompt_dict[m] = {
"roles": roles,
"initial_prompt_value": initial_prompt_value,
"final_prompt_value": final_prompt_value,
}
return litellm.custom_prompt_dict
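
A hypothetical call using the new tokenizer_config branch; the model name and the Hugging Face-style chat_template below are assumptions, only the parameter itself comes from this change:

import litellm

litellm.register_prompt_template(
    model="openrouter/deepseek/deepseek-r1",  # assumed model name
    tokenizer_config={
        "bos_token": "<s>",
        "eos_token": "</s>",
        "chat_template": (
            "{% for message in messages %}"
            "{{ message['role'] }}: {{ message['content'] }}\n"
            "{% endfor %}"
        ),
    },
)
# When tokenizer_config is given, the template is stored in
# litellm.known_tokenizer_config for both the raw and provider-stripped
# model names, instead of litellm.custom_prompt_dict.
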
@@ -5827,6 +5923,18 @@ def convert_to_dict(message: Union[BaseModel, dict]) -> dict:
)
def validate_and_fix_openai_messages(messages: List):
"""
Ensures all messages are valid OpenAI chat completion messages.
Handles missing role for assistant messages.
"""
for message in messages:
if not message.get("role"):
message["role"] = "assistant"
return validate_chat_completion_messages(messages=messages)
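
For illustration, the repair step this helper adds before delegating to validate_chat_completion_messages (message contents are placeholders):

messages = [
    {"role": "user", "content": "Hi"},
    {"content": "Hello! How can I help?"},  # role missing
]
validate_and_fix_openai_messages(messages)
# the second message is patched to role="assistant" before validation runs
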
def validate_chat_completion_messages(messages: List[AllMessageValues]):
"""
Ensures all messages are valid OpenAI chat completion messages.
@@ -5866,6 +5974,10 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
if item.get("type") not in ValidUserMessageContentTypes:
raise Exception("invalid content type")
except Exception as e:
if isinstance(e, KeyError):
raise Exception(
f"Invalid message={m} at index {idx}. Please ensure all messages are valid OpenAI chat completion messages."
)
if "invalid content type" in str(e):
raise Exception(
f"Invalid user message={m} at index {idx}. Please ensure all user messages are valid OpenAI chat completion messages."
@@ -6040,25 +6152,29 @@ class ProviderConfigManager:
elif litellm.LlmProviders.PETALS == provider:
return litellm.PetalsConfig()
elif litellm.LlmProviders.BEDROCK == provider:
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
if (
base_model in litellm.bedrock_converse_models
or "converse_like" in model
):
bedrock_route = BedrockModelInfo.get_bedrock_route(model)
bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(
model
)
if bedrock_route == "converse" or bedrock_route == "converse_like":
return litellm.AmazonConverseConfig()
elif bedrock_provider == "amazon": # amazon titan llms
elif bedrock_invoke_provider == "amazon": # amazon titan llms
return litellm.AmazonTitanConfig()
elif (
bedrock_provider == "meta" or bedrock_provider == "llama"
bedrock_invoke_provider == "meta" or bedrock_invoke_provider == "llama"
): # amazon / meta llms
return litellm.AmazonLlamaConfig()
elif bedrock_provider == "ai21": # ai21 llms
elif bedrock_invoke_provider == "ai21": # ai21 llms
return litellm.AmazonAI21Config()
elif bedrock_provider == "cohere": # cohere models on bedrock
elif bedrock_invoke_provider == "cohere": # cohere models on bedrock
return litellm.AmazonCohereConfig()
elif bedrock_provider == "mistral": # mistral models on bedrock
elif bedrock_invoke_provider == "mistral": # mistral models on bedrock
return litellm.AmazonMistralConfig()
elif bedrock_invoke_provider == "deepseek_r1": # deepseek models on bedrock
return litellm.AmazonDeepSeekR1Config()
else:
return litellm.AmazonInvokeConfig()
return litellm.OpenAIGPTConfig()
@staticmethod
@@ -6078,13 +6194,20 @@ class ProviderConfigManager:
def get_provider_rerank_config(
model: str,
provider: LlmProviders,
api_base: Optional[str],
present_version_params: List[str],
) -> BaseRerankConfig:
if litellm.LlmProviders.COHERE == provider:
return litellm.CohereRerankConfig()
if should_use_cohere_v1_client(api_base, present_version_params):
return litellm.CohereRerankConfig()
else:
return litellm.CohereRerankV2Config()
elif litellm.LlmProviders.AZURE_AI == provider:
return litellm.AzureAIRerankConfig()
elif litellm.LlmProviders.INFINITY == provider:
return litellm.InfinityRerankConfig()
elif litellm.LlmProviders.JINA_AI == provider:
return litellm.JinaAIRerankConfig()
return litellm.CohereRerankConfig()
@staticmethod
@@ -6163,6 +6286,19 @@ def get_end_user_id_for_cost_tracking(
return end_user_id
def should_use_cohere_v1_client(
api_base: Optional[str], present_version_params: List[str]
):
if not api_base:
return False
uses_v1_params = ("max_chunks_per_doc" in present_version_params) and (
"max_tokens_per_doc" not in present_version_params
)
return api_base.endswith("/v1/rerank") or (
uses_v1_params and not api_base.endswith("/v2/rerank")
)
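
The routing rule, spelled out as calls whose results follow directly from the code above (endpoints are placeholders):

from litellm.utils import should_use_cohere_v1_client

# explicit v1 endpoint -> v1 client
should_use_cohere_v1_client("https://api.cohere.ai/v1/rerank", [])                      # True
# v1-only param (max_chunks_per_doc) and no explicit v2 endpoint -> v1 client
should_use_cohere_v1_client("https://example.com/rerank", ["max_chunks_per_doc"])       # True
# explicit v2 endpoint wins even with the v1-only param
should_use_cohere_v1_client("https://api.cohere.ai/v2/rerank", ["max_chunks_per_doc"])  # False
# no api_base -> default (v2) client
should_use_cohere_v1_client(None, [])                                                   # False
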
def is_prompt_caching_valid_prompt(
model: str,
messages: Optional[List[AllMessageValues]],
@@ -6273,7 +6409,9 @@ def get_non_default_completion_params(kwargs: dict) -> dict:
def add_openai_metadata(metadata: dict) -> dict:
"""
Add metadata to openai optional parameters, excluding hidden params
Add metadata to openai optional parameters, excluding hidden params.
OpenAI 'metadata' only supports string values.
Args:
params (dict): Dictionary of API parameters
@@ -6285,5 +6423,10 @@ def add_openai_metadata(metadata: dict) -> dict:
if metadata is None:
return None
# Only include non-hidden parameters
visible_metadata = {k: v for k, v in metadata.items() if k != "hidden_params"}
visible_metadata = {
k: v
for k, v in metadata.items()
if k != "hidden_params" and isinstance(v, (str))
}
return visible_metadata.copy()
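
For illustration, what the stricter filter now keeps and drops (keys and values are placeholders):

add_openai_metadata(
    {
        "user_id": "u-123",                      # kept: plain string
        "hidden_params": {"api_key": "sk-..."},  # dropped: hidden params
        "retry_count": 3,                        # dropped: OpenAI metadata values must be strings
    }
)
# -> {"user_id": "u-123"}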