fix(n param in completion()): fix error thrown when passing n for cohere

This commit is contained in:
ishaan-jaff 2023-10-05 19:54:13 -07:00
parent 1897a1ee46
commit 4e6e79b20a

View file

@ -993,7 +993,8 @@ def get_optional_params( # use the openai defaults
if k not in supported_params:
if k == "n" and n == 1: # langchain sends n=1 as a default value
pass
if k == "request_timeout": # litellm handles request time outs
# Always keeps this in elif code blocks
elif k == "request_timeout": # litellm handles request time outs
pass
else:
unsupported_params.append(k)
@ -1019,7 +1020,7 @@ def get_optional_params( # use the openai defaults
optional_params["max_tokens_to_sample"] = max_tokens
elif custom_llm_provider == "cohere":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "logit_bias", "top_p", "frequency_penalty", "presence_penalty", "stop"]
supported_params = ["stream", "temperature", "max_tokens", "logit_bias", "top_p", "frequency_penalty", "presence_penalty", "stop", "n"]
_check_valid_arg(supported_params=supported_params)
# handle cohere params
if stream: