(fix) use cohere_chat optional params

This commit is contained in:
ishaan-jaff 2024-03-12 14:31:43 -07:00
parent 5b0b251d42
commit b9bfc7c36c
2 changed files with 38 additions and 2 deletions

View file

@@ -22,7 +22,6 @@ def test_chat_completion_cohere():
     try:
         litellm.set_verbose = True
         messages = [
-            {"role": "system", "content": "You're a good bot"},
             {
                 "role": "user",
                 "content": "Hey",
@@ -42,7 +41,6 @@ def test_chat_completion_cohere_stream():
     try:
         litellm.set_verbose = False
         messages = [
-            {"role": "system", "content": "You're a good bot"},
             {
                 "role": "user",
                 "content": "Hey",

View file

@@ -4401,6 +4401,31 @@ def get_optional_params(
             optional_params["presence_penalty"] = presence_penalty
         if stop is not None:
             optional_params["stop_sequences"] = stop
+    elif custom_llm_provider == "cohere_chat":
+        ## check if unsupported param passed in
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        # handle cohere params
+        if stream:
+            optional_params["stream"] = stream
+        if temperature is not None:
+            optional_params["temperature"] = temperature
+        if max_tokens is not None:
+            optional_params["max_tokens"] = max_tokens
+        if n is not None:
+            optional_params["num_generations"] = n
+        if top_p is not None:
+            optional_params["p"] = top_p
+        if frequency_penalty is not None:
+            optional_params["frequency_penalty"] = frequency_penalty
+        if presence_penalty is not None:
+            optional_params["presence_penalty"] = presence_penalty
+        if stop is not None:
+            optional_params["stop_sequences"] = stop
+        if tools is not None:
+            optional_params["tools"] = tools
    elif custom_llm_provider == "maritalk":
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(
@@ -5084,6 +5109,19 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "stop",
             "n",
         ]
+    elif custom_llm_provider == "cohere_chat":
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "frequency_penalty",
+            "presence_penalty",
+            "stop",
+            "n",
+            "tools",
+            "tool_choice",
+        ]
    elif custom_llm_provider == "maritalk":
        return [
            "stream",