(fix) use cohere_chat optional params

This commit is contained in:
ishaan-jaff 2024-03-12 14:31:43 -07:00
parent 5b0b251d42
commit b9bfc7c36c
2 changed files with 38 additions and 2 deletions

View file

@ -4401,6 +4401,31 @@ def get_optional_params(
optional_params["presence_penalty"] = presence_penalty
if stop is not None:
optional_params["stop_sequences"] = stop
elif custom_llm_provider == "cohere_chat":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# handle cohere params
if stream:
optional_params["stream"] = stream
if temperature is not None:
optional_params["temperature"] = temperature
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
if n is not None:
optional_params["num_generations"] = n
if top_p is not None:
optional_params["p"] = top_p
if frequency_penalty is not None:
optional_params["frequency_penalty"] = frequency_penalty
if presence_penalty is not None:
optional_params["presence_penalty"] = presence_penalty
if stop is not None:
optional_params["stop_sequences"] = stop
if tools is not None:
optional_params["tools"] = tools
elif custom_llm_provider == "maritalk":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
@ -5084,6 +5109,19 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"stop",
"n",
]
elif custom_llm_provider == "cohere_chat":
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"tools",
"tool_choice",
]
elif custom_llm_provider == "maritalk":
return [
"stream",