fix(utils.py): add additional providers to get_supported_openai_params

Krrish Dholakia 2024-03-08 22:49:24 -08:00
parent daa371ade9
commit aeb3cbc9b6


@@ -4254,15 +4254,9 @@ def get_optional_params(
     ## raise exception if provider doesn't support passed in param
     if custom_llm_provider == "anthropic":
         ## check if unsupported param passed in
-        supported_params = [
-            "stream",
-            "stop",
-            "temperature",
-            "top_p",
-            "max_tokens",
-            "tools",
-            "tool_choice",
-        ]
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
         _check_valid_arg(supported_params=supported_params)
         # handle anthropic params
         if stream:
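The same swap (a hardcoded per-branch list replaced by one call to `get_supported_openai_params`) repeats in the cohere, openrouter, and openai/azure hunks below. A minimal, self-contained sketch of the pattern follows; it is not litellm's actual implementation: `_check_valid_arg`'s body isn't shown in this diff, so the sketch assumes, per the "raise exception" comment above, that it rejects any unsupported param, and its signature here is guessed from the call site.

```python
# Sketch of the refactor pattern, not litellm's actual code.

def get_supported_openai_params(model: str, custom_llm_provider: str) -> list:
    # Each provider's supported params now live in exactly one place.
    if custom_llm_provider == "anthropic":
        return ["stream", "stop", "temperature", "top_p",
                "max_tokens", "tools", "tool_choice"]
    raise ValueError(f"unknown provider: {custom_llm_provider}")


def _check_valid_arg(supported_params: list, passed_params: dict) -> None:
    # Hypothetical simplified validator (signature assumed from the call site).
    for name in passed_params:
        if name not in supported_params:
            raise ValueError(f"{name} is not supported by this provider")


# One lookup replaces the per-branch hardcoded list:
supported_params = get_supported_openai_params(
    model="claude-3", custom_llm_provider="anthropic"
)
_check_valid_arg(supported_params, {"temperature": 0.7, "max_tokens": 256})
```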
@@ -4286,17 +4280,9 @@ def get_optional_params(
             optional_params["tools"] = tools
     elif custom_llm_provider == "cohere":
         ## check if unsupported param passed in
-        supported_params = [
-            "stream",
-            "temperature",
-            "max_tokens",
-            "logit_bias",
-            "top_p",
-            "frequency_penalty",
-            "presence_penalty",
-            "stop",
-            "n",
-        ]
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
         _check_valid_arg(supported_params=supported_params)
         # handle cohere params
         if stream:
@@ -4897,25 +4883,9 @@ def get_optional_params(
                 extra_body  # openai client supports `extra_body` param
             )
     elif custom_llm_provider == "openrouter":
-        supported_params = [
-            "functions",
-            "function_call",
-            "temperature",
-            "top_p",
-            "n",
-            "stream",
-            "stop",
-            "max_tokens",
-            "presence_penalty",
-            "frequency_penalty",
-            "logit_bias",
-            "user",
-            "response_format",
-            "seed",
-            "tools",
-            "tool_choice",
-            "max_retries",
-        ]
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
         _check_valid_arg(supported_params=supported_params)
         if functions is not None:
@@ -4969,28 +4939,9 @@ def get_optional_params(
             )
     else:  # assume passing in params for openai/azure openai
         print_verbose(f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE")
-        supported_params = [
-            "functions",
-            "function_call",
-            "temperature",
-            "top_p",
-            "n",
-            "stream",
-            "stop",
-            "max_tokens",
-            "presence_penalty",
-            "frequency_penalty",
-            "logit_bias",
-            "user",
-            "response_format",
-            "seed",
-            "tools",
-            "tool_choice",
-            "max_retries",
-            "logprobs",
-            "top_logprobs",
-            "extra_headers",
-        ]
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider="openai"
+        )
         _check_valid_arg(supported_params=supported_params)
         if functions is not None:
             optional_params["functions"] = functions
@@ -5069,6 +5020,79 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "tools",
             "tool_choice",
         ]
+    elif custom_llm_provider == "cohere":
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "logit_bias",
+            "top_p",
+            "frequency_penalty",
+            "presence_penalty",
+            "stop",
+            "n",
+        ]
+    elif custom_llm_provider == "maritalk":
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "presence_penalty",
+            "stop",
+        ]
+    elif custom_llm_provider == "openai" or custom_llm_provider == "azure":
+        return [
+            "functions",
+            "function_call",
+            "temperature",
+            "top_p",
+            "n",
+            "stream",
+            "stop",
+            "max_tokens",
+            "presence_penalty",
+            "frequency_penalty",
+            "logit_bias",
+            "user",
+            "response_format",
+            "seed",
+            "tools",
+            "tool_choice",
+            "max_retries",
+            "logprobs",
+            "top_logprobs",
+            "extra_headers",
+        ]
+    elif custom_llm_provider == "openrouter":
+        return [
+            "functions",
+            "function_call",
+            "temperature",
+            "top_p",
+            "n",
+            "stream",
+            "stop",
+            "max_tokens",
+            "presence_penalty",
+            "frequency_penalty",
+            "logit_bias",
+            "user",
+            "response_format",
+            "seed",
+            "tools",
+            "tool_choice",
+            "max_retries",
+        ]
+    elif custom_llm_provider == "mistral":
+        return [
+            "temperature",
+            "top_p",
+            "stream",
+            "max_tokens",
+            "tools",
+            "tool_choice",
+        ]


 def get_llm_provider(
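As a quick smoke test of the branches this commit adds, assuming litellm at this commit is importable (the `model` argument appears unused by these branches in the diff, so an arbitrary string is passed):

```python
from litellm.utils import get_supported_openai_params

# Providers covered by the new elif branches in get_supported_openai_params.
for provider in ["cohere", "maritalk", "openai", "azure", "openrouter", "mistral"]:
    params = get_supported_openai_params(model="any", custom_llm_provider=provider)
    print(f"{provider}: {params}")

# Spot-check a few provider-specific entries from the diff above.
assert "logit_bias" in get_supported_openai_params(model="any", custom_llm_provider="cohere")
assert "tool_choice" in get_supported_openai_params(model="any", custom_llm_provider="mistral")
```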