diff --git a/litellm/tests/test_cohere_completion.py b/litellm/tests/test_cohere_completion.py index 9c3c9bf93c..372c87b400 100644 --- a/litellm/tests/test_cohere_completion.py +++ b/litellm/tests/test_cohere_completion.py @@ -22,7 +22,6 @@ def test_chat_completion_cohere(): try: litellm.set_verbose = True messages = [ - {"role": "system", "content": "You're a good bot"}, { "role": "user", "content": "Hey", @@ -42,7 +41,6 @@ def test_chat_completion_cohere_stream(): try: litellm.set_verbose = False messages = [ - {"role": "system", "content": "You're a good bot"}, { "role": "user", "content": "Hey", diff --git a/litellm/utils.py b/litellm/utils.py index 1fd8434338..6abba4afa0 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4401,6 +4401,31 @@ def get_optional_params( optional_params["presence_penalty"] = presence_penalty if stop is not None: optional_params["stop_sequences"] = stop + elif custom_llm_provider == "cohere_chat": + ## check if unsupported param passed in + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) + _check_valid_arg(supported_params=supported_params) + # handle cohere params + if stream: + optional_params["stream"] = stream + if temperature is not None: + optional_params["temperature"] = temperature + if max_tokens is not None: + optional_params["max_tokens"] = max_tokens + if n is not None: + optional_params["num_generations"] = n + if top_p is not None: + optional_params["p"] = top_p + if frequency_penalty is not None: + optional_params["frequency_penalty"] = frequency_penalty + if presence_penalty is not None: + optional_params["presence_penalty"] = presence_penalty + if stop is not None: + optional_params["stop_sequences"] = stop + if tools is not None: + optional_params["tools"] = tools elif custom_llm_provider == "maritalk": ## check if unsupported param passed in supported_params = get_supported_openai_params( @@ -5084,6 +5109,19 @@ def get_supported_openai_params(model: str, custom_llm_provider: str): "stop", "n", ] + elif custom_llm_provider == "cohere_chat": + return [ + "stream", + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "n", + "tools", + "tool_choice", + ] elif custom_llm_provider == "maritalk": return [ "stream",