Merge pull request #2479 from BerriAI/litellm_cohere_tool_call

[FEAT] Cohere/command-r tool calling
This commit is contained in:
Ishaan Jaff 2024-03-12 21:20:59 -07:00 committed by GitHub
commit 7b4f9691c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 443 additions and 1 deletions

View file

@@ -4271,6 +4271,7 @@ def get_optional_params(
and custom_llm_provider != "together_ai"
and custom_llm_provider != "mistral"
and custom_llm_provider != "anthropic"
and custom_llm_provider != "cohere_chat"
and custom_llm_provider != "bedrock"
and custom_llm_provider != "ollama_chat"
):
@@ -4402,6 +4403,31 @@ def get_optional_params(
optional_params["presence_penalty"] = presence_penalty
if stop is not None:
optional_params["stop_sequences"] = stop
elif custom_llm_provider == "cohere_chat":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# handle cohere params
if stream:
optional_params["stream"] = stream
if temperature is not None:
optional_params["temperature"] = temperature
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
if n is not None:
optional_params["num_generations"] = n
if top_p is not None:
optional_params["p"] = top_p
if frequency_penalty is not None:
optional_params["frequency_penalty"] = frequency_penalty
if presence_penalty is not None:
optional_params["presence_penalty"] = presence_penalty
if stop is not None:
optional_params["stop_sequences"] = stop
if tools is not None:
optional_params["tools"] = tools
elif custom_llm_provider == "maritalk":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
@@ -5085,6 +5111,19 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"stop",
"n",
]
elif custom_llm_provider == "cohere_chat":
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"tools",
"tool_choice",
]
elif custom_llm_provider == "maritalk":
return [
"stream",