fix(utils.py): *new* get_supported_openai_params() function

Returns the supported openai params for a given model + provider
Krrish Dholakia 2024-03-08 23:06:26 -08:00
parent aeb3cbc9b6
commit fd52b502a6

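For context, here is a minimal usage sketch of the new helper (illustrative only, not part of the diff below; it assumes the function is importable from `litellm.utils` and returns a list of OpenAI-style param names):

```
from litellm.utils import get_supported_openai_params

# Ask which OpenAI-style params Bedrock Claude 3 accepts before building a request.
# The model string mirrors the docstring example; substitute your own model id.
supported = get_supported_openai_params(
    model="anthropic.claude-3", custom_llm_provider="bedrock"
)
if "tools" in supported:
    print("tool calling can be forwarded for this model")
```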

@@ -4305,14 +4305,9 @@ def get_optional_params(
optional_params["stop_sequences"] = stop
elif custom_llm_provider == "maritalk":
## check if unsupported param passed in
supported_params = [
"stream",
"temperature",
"max_tokens",
"top_p",
"presence_penalty",
"stop",
]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# handle cohere params
if stream:
@@ -4331,14 +4326,9 @@ def get_optional_params(
optional_params["stopping_tokens"] = stop
elif custom_llm_provider == "replicate":
## check if unsupported param passed in
supported_params = [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"seed",
]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if stream:
@@ -4359,7 +4349,9 @@ def get_optional_params(
optional_params["stop_sequences"] = stop
elif custom_llm_provider == "huggingface":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
if temperature is not None:
@@ -4398,16 +4390,9 @@ def get_optional_params(
) # since we handle translating echo, we should not send it to TGI request
elif custom_llm_provider == "together_ai":
## check if unsupported param passed in
supported_params = [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"frequency_penalty",
"tools",
"tool_choice",
]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if stream:
@@ -4428,16 +4413,9 @@ def get_optional_params(
optional_params["tool_choice"] = tool_choice
elif custom_llm_provider == "ai21":
## check if unsupported param passed in
supported_params = [
"stream",
"n",
"temperature",
"max_tokens",
"top_p",
"stop",
"frequency_penalty",
"presence_penalty",
]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if stream:
@@ -4460,7 +4438,9 @@ def get_optional_params(
custom_llm_provider == "palm" or custom_llm_provider == "gemini"
): # https://developers.generativeai.google/tutorials/curl_quickstart
## check if unsupported param passed in
supported_params = ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if temperature is not None:
@@ -4489,14 +4469,9 @@ def get_optional_params(
):
print_verbose(f"(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK")
## check if unsupported param passed in
supported_params = [
"temperature",
"top_p",
"max_tokens",
"stream",
"tools",
"tool_choice",
]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if temperature is not None:
@@ -4526,7 +4501,9 @@ def get_optional_params(
)
elif custom_llm_provider == "sagemaker":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
if temperature is not None:
@@ -4553,8 +4530,10 @@ def get_optional_params(
max_tokens = 1
optional_params["max_new_tokens"] = max_tokens
elif custom_llm_provider == "bedrock":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
if "ai21" in model:
supported_params = ["max_tokens", "temperature", "top_p", "stream"]
_check_valid_arg(supported_params=supported_params)
# params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
# https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
@@ -4567,9 +4546,6 @@ def get_optional_params(
if stream:
optional_params["stream"] = stream
elif "anthropic" in model:
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# anthropic params on bedrock
# \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
@@ -4586,7 +4562,6 @@ def get_optional_params(
optional_params=optional_params,
)
elif "amazon" in model: # amazon titan llms
supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"]
_check_valid_arg(supported_params=supported_params)
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
if max_tokens is not None:
@@ -4603,7 +4578,6 @@ def get_optional_params(
if stream:
optional_params["stream"] = stream
elif "meta" in model: # amazon / meta llms
supported_params = ["max_tokens", "temperature", "top_p", "stream"]
_check_valid_arg(supported_params=supported_params)
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
if max_tokens is not None:
@@ -4615,7 +4589,6 @@ def get_optional_params(
if stream:
optional_params["stream"] = stream
elif "cohere" in model: # cohere models on bedrock
supported_params = ["stream", "temperature", "max_tokens"]
_check_valid_arg(supported_params=supported_params)
# handle cohere params
if stream:
@@ -4625,7 +4598,6 @@ def get_optional_params(
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
elif "mistral" in model:
supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"]
_check_valid_arg(supported_params=supported_params)
# mistral params on bedrock
# \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}"
@@ -4669,7 +4641,9 @@ def get_optional_params(
optional_params["stop_sequences"] = stop
elif custom_llm_provider == "cloudflare":
# https://developers.cloudflare.com/workers-ai/models/text-generation/#input
supported_params = ["max_tokens", "stream"]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if max_tokens is not None:
@@ -4677,14 +4651,9 @@ def get_optional_params(
if stream is not None:
optional_params["stream"] = stream
elif custom_llm_provider == "ollama":
supported_params = [
"max_tokens",
"stream",
"top_p",
"temperature",
"frequency_penalty",
"stop",
]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if max_tokens is not None:
@@ -4708,16 +4677,9 @@ def get_optional_params(
non_default_params=non_default_params, optional_params=optional_params
)
elif custom_llm_provider == "nlp_cloud":
supported_params = [
"max_tokens",
"stream",
"temperature",
"top_p",
"presence_penalty",
"frequency_penalty",
"n",
"stop",
]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if max_tokens is not None:
@@ -4737,7 +4699,9 @@ def get_optional_params(
if stop is not None:
optional_params["stop_sequences"] = stop
elif custom_llm_provider == "petals":
supported_params = ["max_tokens", "temperature", "top_p", "stream"]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# max_new_tokens=1,temperature=0.9, top_p=0.6
if max_tokens is not None:
@@ -4749,18 +4713,9 @@ def get_optional_params(
if stream:
optional_params["stream"] = stream
elif custom_llm_provider == "deepinfra":
supported_params = [
"temperature",
"top_p",
"n",
"stream",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if temperature is not None:
if (
@@ -4787,14 +4742,9 @@ def get_optional_params(
if user:
optional_params["user"] = user
elif custom_llm_provider == "perplexity":
supported_params = [
"temperature",
"top_p",
"stream",
"max_tokens",
"presence_penalty",
"frequency_penalty",
]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if temperature is not None:
if (
@@ -4813,15 +4763,9 @@ def get_optional_params(
if frequency_penalty:
optional_params["frequency_penalty"] = frequency_penalty
elif custom_llm_provider == "anyscale":
supported_params = [
"temperature",
"top_p",
"stream",
"max_tokens",
"stop",
"frequency_penalty",
"presence_penalty",
]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
if model in [
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
@@ -4849,14 +4793,9 @@ def get_optional_params(
if max_tokens:
optional_params["max_tokens"] = max_tokens
elif custom_llm_provider == "mistral":
supported_params = [
"temperature",
"top_p",
"stream",
"max_tokens",
"tools",
"tool_choice",
]
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if temperature is not None:
optional_params["temperature"] = temperature
@@ -5002,12 +4941,27 @@ def get_optional_params(
def get_supported_openai_params(model: str, custom_llm_provider: str):
"""
Returns the supported openai params for a given model + provider
Example:
```
get_supported_openai_params(model="anthropic.claude-3", custom_llm_provider="bedrock")
```
"""
if custom_llm_provider == "bedrock":
if model.startswith("anthropic.claude-3"):
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
else:
elif model.startswith("anthropic"):
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
elif model.startswith("ai21"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("amazon"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
elif model.startswith("meta"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("cohere"):
return ["stream", "temperature", "max_tokens"]
elif model.startswith("mistral"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
elif custom_llm_provider == "ollama_chat":
return litellm.OllamaChatConfig().get_supported_openai_params()
elif custom_llm_provider == "anthropic":
@@ -5093,6 +5047,119 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"tools",
"tool_choice",
]
elif custom_llm_provider == "replicate":
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"seed",
]
elif custom_llm_provider == "huggingface":
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
elif custom_llm_provider == "together_ai":
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"frequency_penalty",
"tools",
"tool_choice",
]
elif custom_llm_provider == "ai21":
return [
"stream",
"n",
"temperature",
"max_tokens",
"top_p",
"stop",
"frequency_penalty",
"presence_penalty",
]
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
elif custom_llm_provider == "vertex_ai":
return [
"temperature",
"top_p",
"max_tokens",
"stream",
"tools",
"tool_choice",
]
elif custom_llm_provider == "sagemaker":
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
elif custom_llm_provider == "aleph_alpha":
return [
"max_tokens",
"stream",
"top_p",
"temperature",
"presence_penalty",
"frequency_penalty",
"n",
"stop",
]
elif custom_llm_provider == "cloudflare":
return ["max_tokens", "stream"]
elif custom_llm_provider == "ollama":
return [
"max_tokens",
"stream",
"top_p",
"temperature",
"frequency_penalty",
"stop",
]
elif custom_llm_provider == "nlp_cloud":
return [
"max_tokens",
"stream",
"temperature",
"top_p",
"presence_penalty",
"frequency_penalty",
"n",
"stop",
]
elif custom_llm_provider == "petals":
return ["max_tokens", "temperature", "top_p", "stream"]
elif custom_llm_provider == "deepinfra":
return [
"temperature",
"top_p",
"n",
"stream",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
]
elif custom_llm_provider == "perplexity":
return [
"temperature",
"top_p",
"stream",
"max_tokens",
"presence_penalty",
"frequency_penalty",
]
elif custom_llm_provider == "anyscale":
return [
"temperature",
"top_p",
"stream",
"max_tokens",
"stop",
"frequency_penalty",
"presence_penalty",
]
def get_llm_provider(
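
The refactor above funnels every provider branch through the same pattern: look up the provider's supported params with the new helper, then reject anything outside that list via `_check_valid_arg` (a closure defined inside `get_optional_params`, whose body is not shown in this diff). Below is a hypothetical standalone sketch of that validation idea; the helper name, the `passed_params` dict, the model id, and the error type are all illustrative, and the real closure's behavior (e.g. `drop_params` handling) may differ:

```
from litellm.utils import get_supported_openai_params

def validate_openai_params(passed_params: dict, model: str, custom_llm_provider: str):
    # Fetch the provider/model-specific allow-list from the new helper.
    supported_params = get_supported_openai_params(
        model=model, custom_llm_provider=custom_llm_provider
    )
    # Reject any OpenAI-style param the provider cannot handle.
    unsupported = [k for k in passed_params if k not in supported_params]
    if unsupported:
        raise ValueError(
            f"{custom_llm_provider} does not support parameters: {unsupported}"
        )

# Cloudflare only lists max_tokens and stream above, so top_p is rejected
# (model id is illustrative; the cloudflare branch does not inspect it).
validate_openai_params(
    passed_params={"max_tokens": 256, "top_p": 0.9},
    model="@cf/meta/llama-2-7b-chat-int8",
    custom_llm_provider="cloudflare",
)  # raises ValueError because "top_p" is not in ["max_tokens", "stream"]
```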