diff --git a/litellm/utils.py b/litellm/utils.py index 83f2a7ec1..3d80ddf58 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4254,15 +4254,9 @@ def get_optional_params( ## raise exception if provider doesn't support passed in param if custom_llm_provider == "anthropic": ## check if unsupported param passed in - supported_params = [ - "stream", - "stop", - "temperature", - "top_p", - "max_tokens", - "tools", - "tool_choice", - ] + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) _check_valid_arg(supported_params=supported_params) # handle anthropic params if stream: @@ -4286,17 +4280,9 @@ def get_optional_params( optional_params["tools"] = tools elif custom_llm_provider == "cohere": ## check if unsupported param passed in - supported_params = [ - "stream", - "temperature", - "max_tokens", - "logit_bias", - "top_p", - "frequency_penalty", - "presence_penalty", - "stop", - "n", - ] + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) _check_valid_arg(supported_params=supported_params) # handle cohere params if stream: @@ -4897,25 +4883,9 @@ def get_optional_params( extra_body # openai client supports `extra_body` param ) elif custom_llm_provider == "openrouter": - supported_params = [ - "functions", - "function_call", - "temperature", - "top_p", - "n", - "stream", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "response_format", - "seed", - "tools", - "tool_choice", - "max_retries", - ] + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) _check_valid_arg(supported_params=supported_params) if functions is not None: @@ -4969,28 +4939,9 @@ def get_optional_params( ) else: # assume passing in params for openai/azure openai print_verbose(f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE") - supported_params = [ - "functions", - "function_call", - "temperature", - "top_p", - "n", - "stream", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "response_format", - "seed", - "tools", - "tool_choice", - "max_retries", - "logprobs", - "top_logprobs", - "extra_headers", - ] + supported_params = get_supported_openai_params( + model=model, custom_llm_provider="openai" + ) _check_valid_arg(supported_params=supported_params) if functions is not None: optional_params["functions"] = functions @@ -5069,6 +5020,79 @@ def get_supported_openai_params(model: str, custom_llm_provider: str): "tools", "tool_choice", ] + elif custom_llm_provider == "cohere": + return [ + "stream", + "temperature", + "max_tokens", + "logit_bias", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "n", + ] + elif custom_llm_provider == "maritalk": + return [ + "stream", + "temperature", + "max_tokens", + "top_p", + "presence_penalty", + "stop", + ] + elif custom_llm_provider == "openai" or custom_llm_provider == "azure": + return [ + "functions", + "function_call", + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + "logprobs", + "top_logprobs", + "extra_headers", + ] + elif custom_llm_provider == "openrouter": + return [ + "functions", + "function_call", + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + ] + elif custom_llm_provider == "mistral": + return [ + "temperature", + "top_p", + "stream", + "max_tokens", + "tools", + "tool_choice", + ] def get_llm_provider(