mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(utils.py): support deepinfra optional params
Fixes https://github.com/BerriAI/litellm/issues/3855
This commit is contained in:
parent
a6a84e57ce
commit
f0f853b941
3 changed files with 109 additions and 38 deletions
|
@ -766,7 +766,12 @@ from .llms.bedrock import (
|
||||||
AmazonMistralConfig,
|
AmazonMistralConfig,
|
||||||
AmazonBedrockGlobalConfig,
|
AmazonBedrockGlobalConfig,
|
||||||
)
|
)
|
||||||
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig, MistralConfig
|
from .llms.openai import (
|
||||||
|
OpenAIConfig,
|
||||||
|
OpenAITextCompletionConfig,
|
||||||
|
MistralConfig,
|
||||||
|
DeepInfraConfig,
|
||||||
|
)
|
||||||
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
|
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
|
||||||
from .llms.watsonx import IBMWatsonXAIConfig
|
from .llms.watsonx import IBMWatsonXAIConfig
|
||||||
from .main import * # type: ignore
|
from .main import * # type: ignore
|
||||||
|
|
|
@ -157,6 +157,101 @@ class MistralConfig:
|
||||||
)
|
)
|
||||||
if param == "seed":
|
if param == "seed":
|
||||||
optional_params["extra_body"] = {"random_seed": value}
|
optional_params["extra_body"] = {"random_seed": value}
|
||||||
|
if param == "response_format":
|
||||||
|
optional_params["response_format"] = value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
|
class DeepInfraConfig:
|
||||||
|
"""
|
||||||
|
Reference: https://deepinfra.com/docs/advanced/openai_api
|
||||||
|
|
||||||
|
The class `DeepInfra` provides configuration for the DeepInfra's Chat Completions API interface. Below are the parameters:
|
||||||
|
"""
|
||||||
|
|
||||||
|
frequency_penalty: Optional[int] = None
|
||||||
|
function_call: Optional[Union[str, dict]] = None
|
||||||
|
functions: Optional[list] = None
|
||||||
|
logit_bias: Optional[dict] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
n: Optional[int] = None
|
||||||
|
presence_penalty: Optional[int] = None
|
||||||
|
stop: Optional[Union[str, list]] = None
|
||||||
|
temperature: Optional[int] = None
|
||||||
|
top_p: Optional[int] = None
|
||||||
|
response_format: Optional[dict] = None
|
||||||
|
tools: Optional[list] = None
|
||||||
|
tool_choice: Optional[Union[str, dict]] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
frequency_penalty: Optional[int] = None,
|
||||||
|
function_call: Optional[Union[str, dict]] = None,
|
||||||
|
functions: Optional[list] = None,
|
||||||
|
logit_bias: Optional[dict] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
|
presence_penalty: Optional[int] = None,
|
||||||
|
stop: Optional[Union[str, list]] = None,
|
||||||
|
temperature: Optional[int] = None,
|
||||||
|
top_p: Optional[int] = None,
|
||||||
|
response_format: Optional[dict] = None,
|
||||||
|
tools: Optional[list] = None,
|
||||||
|
tool_choice: Optional[Union[str, dict]] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals().copy()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self):
|
||||||
|
return [
|
||||||
|
"frequency_penalty",
|
||||||
|
"function_call",
|
||||||
|
"functions",
|
||||||
|
"logit_bias",
|
||||||
|
"max_tokens",
|
||||||
|
"n",
|
||||||
|
"presence_penalty",
|
||||||
|
"stop",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"response_format",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self, non_default_params: dict, optional_params: dict, model: str
|
||||||
|
):
|
||||||
|
supported_openai_params = self.get_supported_openai_params()
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if (
|
||||||
|
param == "temperature"
|
||||||
|
and value == 0
|
||||||
|
and model == "mistralai/Mistral-7B-Instruct-v0.1"
|
||||||
|
): # this model does no support temperature == 0
|
||||||
|
value = 0.0001 # close to 0
|
||||||
|
if param in supported_openai_params:
|
||||||
|
optional_params[param] = value
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
|
@ -197,6 +292,7 @@ class OpenAIConfig:
|
||||||
stop: Optional[Union[str, list]] = None
|
stop: Optional[Union[str, list]] = None
|
||||||
temperature: Optional[int] = None
|
temperature: Optional[int] = None
|
||||||
top_p: Optional[int] = None
|
top_p: Optional[int] = None
|
||||||
|
response_format: Optional[dict] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -210,6 +306,7 @@ class OpenAIConfig:
|
||||||
stop: Optional[Union[str, list]] = None,
|
stop: Optional[Union[str, list]] = None,
|
||||||
temperature: Optional[int] = None,
|
temperature: Optional[int] = None,
|
||||||
top_p: Optional[int] = None,
|
top_p: Optional[int] = None,
|
||||||
|
response_format: Optional[dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
locals_ = locals().copy()
|
locals_ = locals().copy()
|
||||||
for key, value in locals_.items():
|
for key, value in locals_.items():
|
||||||
|
|
|
@ -5797,30 +5797,11 @@ def get_optional_params(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
)
|
)
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
if temperature is not None:
|
optional_params = litellm.DeepInfraConfig().map_openai_params(
|
||||||
if (
|
non_default_params=non_default_params,
|
||||||
temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1"
|
optional_params=optional_params,
|
||||||
): # this model does no support temperature == 0
|
model=model,
|
||||||
temperature = 0.0001 # close to 0
|
)
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p:
|
|
||||||
optional_params["top_p"] = top_p
|
|
||||||
if n:
|
|
||||||
optional_params["n"] = n
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
if stop:
|
|
||||||
optional_params["stop"] = stop
|
|
||||||
if max_tokens:
|
|
||||||
optional_params["max_tokens"] = max_tokens
|
|
||||||
if presence_penalty:
|
|
||||||
optional_params["presence_penalty"] = presence_penalty
|
|
||||||
if frequency_penalty:
|
|
||||||
optional_params["frequency_penalty"] = frequency_penalty
|
|
||||||
if logit_bias:
|
|
||||||
optional_params["logit_bias"] = logit_bias
|
|
||||||
if user:
|
|
||||||
optional_params["user"] = user
|
|
||||||
elif custom_llm_provider == "perplexity":
|
elif custom_llm_provider == "perplexity":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
@ -6604,19 +6585,7 @@ def get_supported_openai_params(
|
||||||
elif custom_llm_provider == "petals":
|
elif custom_llm_provider == "petals":
|
||||||
return ["max_tokens", "temperature", "top_p", "stream"]
|
return ["max_tokens", "temperature", "top_p", "stream"]
|
||||||
elif custom_llm_provider == "deepinfra":
|
elif custom_llm_provider == "deepinfra":
|
||||||
return [
|
return litellm.DeepInfraConfig().get_supported_openai_params()
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"n",
|
|
||||||
"stream",
|
|
||||||
"stop",
|
|
||||||
"max_tokens",
|
|
||||||
"presence_penalty",
|
|
||||||
"frequency_penalty",
|
|
||||||
"logit_bias",
|
|
||||||
"user",
|
|
||||||
"response_format",
|
|
||||||
]
|
|
||||||
elif custom_llm_provider == "perplexity":
|
elif custom_llm_provider == "perplexity":
|
||||||
return [
|
return [
|
||||||
"temperature",
|
"temperature",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue