fix(utils.py): add additional providers to get_supported_openai_params

Author: Krrish Dholakia
Date:   2024-03-08 22:49:24 -08:00
Commit: aeb3cbc9b6
Parent: daa371ade9


@@ -4254,15 +4254,9 @@ def get_optional_params(
     ## raise exception if provider doesn't support passed in param
     if custom_llm_provider == "anthropic":
         ## check if unsupported param passed in
-        supported_params = [
-            "stream",
-            "stop",
-            "temperature",
-            "top_p",
-            "max_tokens",
-            "tools",
-            "tool_choice",
-        ]
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
         _check_valid_arg(supported_params=supported_params)
         # handle anthropic params
         if stream:
@@ -4286,17 +4280,9 @@ def get_optional_params(
             optional_params["tools"] = tools
     elif custom_llm_provider == "cohere":
         ## check if unsupported param passed in
-        supported_params = [
-            "stream",
-            "temperature",
-            "max_tokens",
-            "logit_bias",
-            "top_p",
-            "frequency_penalty",
-            "presence_penalty",
-            "stop",
-            "n",
-        ]
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
         _check_valid_arg(supported_params=supported_params)
         # handle cohere params
         if stream:
@@ -4897,25 +4883,9 @@ def get_optional_params(
             extra_body  # openai client supports `extra_body` param
         )
     elif custom_llm_provider == "openrouter":
-        supported_params = [
-            "functions",
-            "function_call",
-            "temperature",
-            "top_p",
-            "n",
-            "stream",
-            "stop",
-            "max_tokens",
-            "presence_penalty",
-            "frequency_penalty",
-            "logit_bias",
-            "user",
-            "response_format",
-            "seed",
-            "tools",
-            "tool_choice",
-            "max_retries",
-        ]
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
         _check_valid_arg(supported_params=supported_params)

         if functions is not None:
@@ -4969,28 +4939,9 @@ def get_optional_params(
         )
     else:  # assume passing in params for openai/azure openai
         print_verbose(f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE")
-        supported_params = [
-            "functions",
-            "function_call",
-            "temperature",
-            "top_p",
-            "n",
-            "stream",
-            "stop",
-            "max_tokens",
-            "presence_penalty",
-            "frequency_penalty",
-            "logit_bias",
-            "user",
-            "response_format",
-            "seed",
-            "tools",
-            "tool_choice",
-            "max_retries",
-            "logprobs",
-            "top_logprobs",
-            "extra_headers",
-        ]
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider="openai"
+        )
         _check_valid_arg(supported_params=supported_params)
         if functions is not None:
             optional_params["functions"] = functions
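All four hunks above collapse to the same call-site pattern: look the provider's allow-list up once, then validate the caller's params against it. A minimal sketch of that pattern in isolation, assuming get_supported_openai_params is importable from litellm.utils (where this diff defines it); the _check_valid_arg below is a simplified stand-in for the closure that get_optional_params actually defines:

# Sketch only: the real _check_valid_arg is a closure inside
# get_optional_params; this stand-in just shows the shape of the check.
from litellm.utils import get_supported_openai_params

def _check_valid_arg(supported_params, passed_params):
    # Reject any OpenAI-style param the target provider cannot map.
    unsupported = [k for k in passed_params if k not in supported_params]
    if unsupported:
        raise ValueError(f"params not supported by this provider: {unsupported}")

supported_params = get_supported_openai_params(
    model="command-nightly", custom_llm_provider="cohere"  # illustrative values
)
_check_valid_arg(supported_params, {"temperature": 0.2, "n": 1})  # passes
_check_valid_arg(supported_params, {"logprobs": 2})  # raises ValueError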
@@ -5069,6 +5020,79 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "tools",
             "tool_choice",
         ]
+    elif custom_llm_provider == "cohere":
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "logit_bias",
+            "top_p",
+            "frequency_penalty",
+            "presence_penalty",
+            "stop",
+            "n",
+        ]
+    elif custom_llm_provider == "maritalk":
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "presence_penalty",
+            "stop",
+        ]
+    elif custom_llm_provider == "openai" or custom_llm_provider == "azure":
+        return [
+            "functions",
+            "function_call",
+            "temperature",
+            "top_p",
+            "n",
+            "stream",
+            "stop",
+            "max_tokens",
+            "presence_penalty",
+            "frequency_penalty",
+            "logit_bias",
+            "user",
+            "response_format",
+            "seed",
+            "tools",
+            "tool_choice",
+            "max_retries",
+            "logprobs",
+            "top_logprobs",
+            "extra_headers",
+        ]
+    elif custom_llm_provider == "openrouter":
+        return [
+            "functions",
+            "function_call",
+            "temperature",
+            "top_p",
+            "n",
+            "stream",
+            "stop",
+            "max_tokens",
+            "presence_penalty",
+            "frequency_penalty",
+            "logit_bias",
+            "user",
+            "response_format",
+            "seed",
+            "tools",
+            "tool_choice",
+            "max_retries",
+        ]
+    elif custom_llm_provider == "mistral":
+        return [
+            "temperature",
+            "top_p",
+            "stream",
+            "max_tokens",
+            "tools",
+            "tool_choice",
+        ]


 def get_llm_provider(
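With the new branches merged, the per-provider allow-lists become queryable in one place instead of living as copies inside get_optional_params. A short usage sketch, again assuming the import path from this diff; the model names are illustrative, since the branches shown above dispatch on custom_llm_provider alone:

from litellm.utils import get_supported_openai_params

print(get_supported_openai_params(
    model="mistral-medium", custom_llm_provider="mistral"  # model name is illustrative
))
# ['temperature', 'top_p', 'stream', 'max_tokens', 'tools', 'tool_choice']

print("logit_bias" in get_supported_openai_params(
    model="command-nightly", custom_llm_provider="cohere"
))
# True

Because validation now reads from this same table, adding a provider (or a param) in get_supported_openai_params updates both the public lookup and the guard in get_optional_params, so the two can no longer drift apart.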