support cohere top p

This commit is contained in:
Krrish Dholakia 2023-10-02 21:38:56 -07:00
parent 4ce5c1d1dc
commit 4ea7162a81
4 changed files with 6 additions and 3 deletions

View file

@ -47,13 +47,14 @@ def test_completion_return_full_text_hf():
def test_completion_invalid_param_cohere(): def test_completion_invalid_param_cohere():
try: try:
response = completion(model="command-nightly", messages=messages, top_p=1) response = completion(model="command-nightly", messages=messages, top_p=1)
print(f"response: {response}")
except Exception as e: except Exception as e:
if "Unsupported parameters passed: top_p" in str(e): if "Unsupported parameters passed: top_p" in str(e):
pass pass
else: else:
pytest.fail(f'An error occurred {e}') pytest.fail(f'An error occurred {e}')
# test_completion_invalid_param_cohere() test_completion_invalid_param_cohere()
def test_completion_function_call_cohere(): def test_completion_function_call_cohere():
try: try:

View file

@ -1003,7 +1003,7 @@ def get_optional_params( # use the openai defaults
optional_params["max_tokens_to_sample"] = max_tokens optional_params["max_tokens_to_sample"] = max_tokens
elif custom_llm_provider == "cohere": elif custom_llm_provider == "cohere":
## check if unsupported param passed in ## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "logit_bias"] supported_params = ["stream", "temperature", "max_tokens", "logit_bias", "top_p"]
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
# handle cohere params # handle cohere params
if stream: if stream:
@ -1014,6 +1014,8 @@ def get_optional_params( # use the openai defaults
optional_params["max_tokens"] = max_tokens optional_params["max_tokens"] = max_tokens
if logit_bias != {}: if logit_bias != {}:
optional_params["logit_bias"] = logit_bias optional_params["logit_bias"] = logit_bias
if top_p:
optional_params["p"] = top_p
elif custom_llm_provider == "replicate": elif custom_llm_provider == "replicate":
## check if unsupported param passed in ## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"] supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"]

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "litellm" name = "litellm"
version = "0.1.814" version = "0.1.815"
description = "Library to easily interface with LLM API providers" description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"] authors = ["BerriAI"]
license = "MIT License" license = "MIT License"