support cohere top_p

This commit is contained in:
Krrish Dholakia 2023-10-02 21:38:56 -07:00
parent 4ce5c1d1dc
commit 4ea7162a81
4 changed files with 6 additions and 3 deletions

View file

@@ -47,13 +47,14 @@ def test_completion_return_full_text_hf():
def test_completion_invalid_param_cohere():
    """Call Cohere (command-nightly) with top_p=1 and verify the outcome.

    The call is allowed to either succeed outright or raise the specific
    "Unsupported parameters passed: top_p" error; any other exception
    fails the test.
    """
    try:
        result = completion(model="command-nightly", messages=messages, top_p=1)
        print(f"response: {result}")
    except Exception as err:
        # Only the known "unsupported param" error is acceptable here.
        if "Unsupported parameters passed: top_p" not in str(err):
            pytest.fail(f'An error occurred {err}')
# test_completion_invalid_param_cohere()
test_completion_invalid_param_cohere()
def test_completion_function_call_cohere():
try:

View file

@@ -1003,7 +1003,7 @@ def get_optional_params( # use the openai defaults
optional_params["max_tokens_to_sample"] = max_tokens
elif custom_llm_provider == "cohere":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "logit_bias"]
supported_params = ["stream", "temperature", "max_tokens", "logit_bias", "top_p"]
_check_valid_arg(supported_params=supported_params)
# handle cohere params
if stream:
@@ -1014,6 +1014,8 @@
optional_params["max_tokens"] = max_tokens
if logit_bias != {}:
optional_params["logit_bias"] = logit_bias
if top_p:
optional_params["p"] = top_p
elif custom_llm_provider == "replicate":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"]

View file

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.814"
version = "0.1.815"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"