mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(n param in completion()): fix error thrown when passing n for cohere
This commit is contained in:
parent
1897a1ee46
commit
4e6e79b20a
1 changed files with 4 additions and 3 deletions
|
@ -978,7 +978,7 @@ def get_optional_params( # use the openai defaults
|
||||||
}
|
}
|
||||||
# filter out those parameters that were passed with non-default values
|
# filter out those parameters that were passed with non-default values
|
||||||
non_default_params = {k: v for k, v in passed_params.items() if (k != "model" and k != "custom_llm_provider" and k in default_params and v != default_params[k])}
|
non_default_params = {k: v for k, v in passed_params.items() if (k != "model" and k != "custom_llm_provider" and k in default_params and v != default_params[k])}
|
||||||
|
|
||||||
## raise exception if function calling passed in for a provider that doesn't support it
|
## raise exception if function calling passed in for a provider that doesn't support it
|
||||||
if "functions" in non_default_params or "function_call" in non_default_params:
|
if "functions" in non_default_params or "function_call" in non_default_params:
|
||||||
if custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure":
|
if custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure":
|
||||||
|
@ -993,7 +993,8 @@ def get_optional_params( # use the openai defaults
|
||||||
if k not in supported_params:
|
if k not in supported_params:
|
||||||
if k == "n" and n == 1: # langchain sends n=1 as a default value
|
if k == "n" and n == 1: # langchain sends n=1 as a default value
|
||||||
pass
|
pass
|
||||||
if k == "request_timeout": # litellm handles request time outs
|
# Always keeps this in elif code blocks
|
||||||
|
elif k == "request_timeout": # litellm handles request time outs
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
unsupported_params.append(k)
|
unsupported_params.append(k)
|
||||||
|
@ -1019,7 +1020,7 @@ def get_optional_params( # use the openai defaults
|
||||||
optional_params["max_tokens_to_sample"] = max_tokens
|
optional_params["max_tokens_to_sample"] = max_tokens
|
||||||
elif custom_llm_provider == "cohere":
|
elif custom_llm_provider == "cohere":
|
||||||
## check if unsupported param passed in
|
## check if unsupported param passed in
|
||||||
supported_params = ["stream", "temperature", "max_tokens", "logit_bias", "top_p", "frequency_penalty", "presence_penalty", "stop"]
|
supported_params = ["stream", "temperature", "max_tokens", "logit_bias", "top_p", "frequency_penalty", "presence_penalty", "stop", "n"]
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
# handle cohere params
|
# handle cohere params
|
||||||
if stream:
|
if stream:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue