Merge pull request #2474 from BerriAI/litellm_support_command_r

[New-Model] Cohere/command-r
Ishaan Jaff 2024-03-12 11:11:56 -07:00 committed by GitHub
commit 5172fb1de9
9 changed files with 386 additions and 27 deletions


@@ -17,7 +17,7 @@ os.environ["COHERE_API_KEY"] = "cohere key"
# cohere call
response = completion(
model="command-nightly",
model="command-r",
messages = [{ "content": "Hello, how are you?","role": "user"}]
)
```
@@ -32,7 +32,7 @@ os.environ["COHERE_API_KEY"] = "cohere key"
# cohere call
response = completion(
model="command-nightly",
model="command-r",
messages = [{ "content": "Hello, how are you?","role": "user"}],
stream=True
)
@@ -41,7 +41,17 @@ for chunk in response:
print(chunk)
```
LiteLLM supports 'command', 'command-light', 'command-medium', 'command-medium-beta', 'command-xlarge-beta', 'command-nightly' models from [Cohere](https://cohere.com/).
## Supported Models
| Model Name | Function Call |
|------------|----------------|
| command-r | `completion('command-r', messages)` |
| command-light | `completion('command-light', messages)` |
| command-medium | `completion('command-medium', messages)` |
| command-medium-beta | `completion('command-medium-beta', messages)` |
| command-xlarge-beta | `completion('command-xlarge-beta', messages)` |
| command-nightly | `completion('command-nightly', messages)` |
## Embedding
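
Editor's note: the docs diff above switches the quickstart snippets from `command-nightly` to `command-r`, and the tests added later in this PR call the same model with an explicit `cohere_chat/` prefix. A minimal usage sketch combining both forms (the API key value is a placeholder):

```python
# Editor's sketch, not part of the diff: calling command-r via the new cohere_chat provider.
import os
from litellm import completion

os.environ["COHERE_API_KEY"] = "cohere key"  # placeholder

# bare model name, as in the updated docs
response = completion(
    model="command-r",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)

# explicit provider prefix, as used in the new tests
response = completion(
    model="cohere_chat/command-r",
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=10,
)
print(response)
```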


@@ -131,6 +131,7 @@ const sidebars = {
"providers/anthropic",
"providers/aws_sagemaker",
"providers/bedrock",
"providers/cohere",
"providers/anyscale",
"providers/huggingface",
"providers/ollama",
@@ -143,7 +144,6 @@ const sidebars = {
"providers/ai21",
"providers/nlp_cloud",
"providers/replicate",
"providers/cohere",
"providers/togetherai",
"providers/voyage",
"providers/aleph_alpha",


@@ -252,6 +252,7 @@ config_path = None
open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []
cohere_models: List = []
cohere_chat_models: List = []
anthropic_models: List = []
openrouter_models: List = []
vertex_language_models: List = []
@@ -274,6 +275,8 @@ for key, value in model_cost.items():
open_ai_text_completion_models.append(key)
elif value.get("litellm_provider") == "cohere":
cohere_models.append(key)
elif value.get("litellm_provider") == "cohere_chat":
cohere_chat_models.append(key)
elif value.get("litellm_provider") == "anthropic":
anthropic_models.append(key)
elif value.get("litellm_provider") == "openrouter":
@@ -421,6 +424,7 @@ model_list = (
open_ai_chat_completion_models
+ open_ai_text_completion_models
+ cohere_models
+ cohere_chat_models
+ anthropic_models
+ replicate_models
+ openrouter_models
@@ -444,6 +448,7 @@ provider_list: List = [
"custom_openai",
"text-completion-openai",
"cohere",
"cohere_chat",
"anthropic",
"replicate",
"huggingface",
@@ -479,6 +484,7 @@ provider_list: List = [
models_by_provider: dict = {
"openai": open_ai_chat_completion_models + open_ai_text_completion_models,
"cohere": cohere_models,
"cohere_chat": cohere_chat_models,
"anthropic": anthropic_models,
"replicate": replicate_models,
"huggingface": huggingface_models,

litellm/llms/cohere_chat.py (new file, 204 lines added)

@@ -0,0 +1,204 @@
import os, types
import json
from enum import Enum
import requests
import time, traceback
from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx
class CohereError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class CohereChatConfig:
"""
Configuration class for Cohere's API interface.
Args:
preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
generation_id (str, optional): Unique identifier for the generated reply.
response_id (str, optional): Unique identifier for the response.
conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
"""
preamble: Optional[str] = None
chat_history: Optional[list] = None
generation_id: Optional[str] = None
response_id: Optional[str] = None
conversation_id: Optional[str] = None
prompt_truncation: Optional[str] = None
connectors: Optional[list] = None
search_queries_only: Optional[bool] = None
documents: Optional[list] = None
temperature: Optional[int] = None
max_tokens: Optional[int] = None
k: Optional[int] = None
p: Optional[int] = None
frequency_penalty: Optional[int] = None
presence_penalty: Optional[int] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
def __init__(
self,
preamble: Optional[str] = None,
chat_history: Optional[list] = None,
generation_id: Optional[str] = None,
response_id: Optional[str] = None,
conversation_id: Optional[str] = None,
prompt_truncation: Optional[str] = None,
connectors: Optional[list] = None,
search_queries_only: Optional[bool] = None,
documents: Optional[list] = None,
temperature: Optional[int] = None,
max_tokens: Optional[int] = None,
k: Optional[int] = None,
p: Optional[int] = None,
frequency_penalty: Optional[int] = None,
presence_penalty: Optional[int] = None,
tools: Optional[list] = None,
tool_results: Optional[list] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
headers = {
"accept": "application/json",
"content-type": "application/json",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
completion_url = api_base
model = model
prompt = " ".join(message["content"] for message in messages)
## Load Config
config = litellm.CohereChatConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"model": model,
"message": prompt,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
)
## COMPLETION CALL
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
## error handling for cohere calls
if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
else:
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
completion_response = response.json()
try:
model_response.choices[0].message.content = completion_response["text"] # type: ignore
except Exception as e:
raise CohereError(message=response.text, status_code=response.status_code)
## CALCULATING USAGE - use cohere `billed_units` for returning usage
billed_units = completion_response.get("meta", {}).get("billed_units", {})
prompt_tokens = billed_units.get("input_tokens", 0)
completion_tokens = billed_units.get("output_tokens", 0)
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
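
Editor's note: the new handler is a thin wrapper over a `requests.post` to Cohere's `/v1/chat` endpoint. Messages are flattened into a single `message` string, `CohereChatConfig` defaults are merged into `optional_params`, and on the non-streaming path the response `text` and `meta.billed_units` are mapped onto a `ModelResponse`. A standalone sketch of the same request and usage mapping, independent of litellm (endpoint and field names taken from the handler above; the key is a placeholder):

```python
# Editor's sketch: the raw Cohere /v1/chat call the new handler wraps.
import os
import requests

api_key = os.environ.get("COHERE_API_KEY", "cohere key")  # placeholder
resp = requests.post(
    "https://api.cohere.ai/v1/chat",
    headers={
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}",
    },
    json={"model": "command-r", "message": "Hello, how are you?", "max_tokens": 10},
)
resp.raise_for_status()
body = resp.json()

text = body["text"]                                    # mapped to choices[0].message.content
billed = body.get("meta", {}).get("billed_units", {})  # mapped to Usage
usage = {
    "prompt_tokens": billed.get("input_tokens", 0),
    "completion_tokens": billed.get("output_tokens", 0),
}
print(text, usage)
```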


@@ -55,6 +55,7 @@ from .llms import (
ollama_chat,
cloudflare,
cohere,
cohere_chat,
petals,
oobabooga,
openrouter,
@@ -1287,6 +1288,46 @@ def completion(
)
return response
response = model_response
elif custom_llm_provider == "cohere_chat":
cohere_key = (
api_key
or litellm.cohere_key
or get_secret("COHERE_API_KEY")
or get_secret("CO_API_KEY")
or litellm.api_key
)
api_base = (
api_base
or litellm.api_base
or get_secret("COHERE_API_BASE")
or "https://api.cohere.ai/v1/chat"
)
model_response = cohere_chat.completion(
model=model,
messages=messages,
api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=cohere_key,
logging_obj=logging,  # model call logging is done inside the handler, since we may need to modify I/O to fit Cohere's requirements
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access the stream object directly; wrap it in CustomStreamWrapper
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="cohere_chat",
logging_obj=logging,
)
return response
response = model_response
elif custom_llm_provider == "maritalk":
maritalk_key = (
api_key
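
Editor's note: the new `cohere_chat` branch in `main.py` mirrors the existing `cohere` branch: the key is resolved from `api_key`, `litellm.cohere_key`, the `COHERE_API_KEY`/`CO_API_KEY` environment variables, or `litellm.api_key`; the base URL defaults to `https://api.cohere.ai/v1/chat`; and streaming responses are wrapped in `CustomStreamWrapper`. A hedged sketch of driving those resolution points from calling code, assuming litellm's usual OpenAI-style streaming chunks:

```python
# Editor's sketch: key/base resolution and streaming through the new cohere_chat branch.
import litellm
from litellm import completion

litellm.cohere_key = "cohere key"  # placeholder; COHERE_API_KEY or CO_API_KEY env vars also work

response = completion(
    model="cohere_chat/command-r",
    messages=[{"role": "user", "content": "Hey"}],
    api_base="https://api.cohere.ai/v1/chat",  # the default shown in the diff; override for a proxy
    stream=True,                               # returns a CustomStreamWrapper
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```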


@@ -981,35 +981,45 @@
"litellm_provider": "gemini",
"mode": "chat"
},
"command-nightly": {
"cohere_chat/command-r": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000050,
"output_cost_per_token": 0.0000015,
"litellm_provider": "cohere_chat",
"mode": "chat"
},
"cohere_chat/command-light": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,
"litellm_provider": "cohere_chat",
"mode": "chat"
},
"cohere/command-nightly": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,
"litellm_provider": "cohere",
"mode": "completion"
},
"command": {
"cohere/command": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,
"litellm_provider": "cohere",
"mode": "completion"
},
"command-light": {
"cohere/command-medium-beta": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,
"litellm_provider": "cohere",
"mode": "completion"
},
"command-medium-beta": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,
"litellm_provider": "cohere",
"mode": "completion"
},
"command-xlarge-beta": {
"cohere/command-xlarge-beta": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,


@@ -1984,6 +1984,50 @@ def test_completion_cohere():
pytest.fail(f"Error occurred: {e}")
# FYI - cohere_chat looks quite unstable, even when testing locally
def test_chat_completion_cohere():
try:
litellm.set_verbose = True
messages = [
{"role": "system", "content": "You're a good bot"},
{
"role": "user",
"content": "Hey",
},
]
response = completion(
model="cohere_chat/command-r",
messages=messages,
max_tokens=10,
)
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_chat_completion_cohere_stream():
try:
litellm.set_verbose = False
messages = [
{"role": "system", "content": "You're a good bot"},
{
"role": "user",
"content": "Hey",
},
]
response = completion(
model="cohere_chat/command-r",
messages=messages,
max_tokens=10,
stream=True,
)
print(response)
for chunk in response:
print(chunk)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_azure_cloudflare_api():
litellm.set_verbose = True
try:
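
Editor's note: the new tests exercise the non-streaming and streaming paths for `cohere_chat/command-r` (with the caveat in the comment that the endpoint was flaky at the time). Since the handler fills `usage` from Cohere's `billed_units`, a natural extra check, sketched here rather than taken from the diff and assuming Cohere reports billed units for the call, is that the token counts come back populated and consistent:

```python
# Editor's sketch: an extra usage check on top of the tests above.
from litellm import completion

response = completion(
    model="cohere_chat/command-r",
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=10,
)
usage = response.usage
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
assert usage.prompt_tokens > 0  # populated from meta.billed_units.input_tokens
```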


@@ -7411,7 +7411,9 @@ def exception_type(
model=model,
response=original_exception.response,
)
elif custom_llm_provider == "cohere": # Cohere
elif (
custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat"
): # Cohere
if (
"invalid api token" in error_str
or "No API key provided." in error_str
@@ -8544,6 +8546,29 @@ class CustomStreamWrapper:
except:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_cohere_chat_chunk(self, chunk):
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
print_verbose(f"chunk: {chunk}")
try:
text = ""
is_finished = False
finish_reason = ""
if "text" in data_json:
text = data_json["text"]
elif "is_finished" in data_json and data_json["is_finished"] == True:
is_finished = data_json["is_finished"]
finish_reason = data_json["finish_reason"]
else:
return
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
except:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_azure_chunk(self, chunk):
is_finished = False
finish_reason = ""
@@ -9073,6 +9098,15 @@ class CustomStreamWrapper:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
elif self.custom_llm_provider == "cohere_chat":
response_obj = self.handle_cohere_chat_chunk(chunk)
if response_obj is None:
return
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
elif self.custom_llm_provider == "bedrock":
if self.sent_last_chunk:
raise StopIteration
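
Editor's note: `handle_cohere_chat_chunk` treats each streamed line as a standalone JSON event; chunks carrying `text` contribute content, a chunk with `is_finished: true` carries the `finish_reason`, and anything else (for example a stream-start event) is skipped by returning nothing. A standalone sketch of that parsing against representative payloads; the sample events are illustrative, modeled on Cohere's documented stream events rather than copied from the diff:

```python
# Editor's sketch: the per-chunk parsing performed by handle_cohere_chat_chunk.
import json

def parse_cohere_chat_chunk(raw: bytes):
    data = json.loads(raw.decode("utf-8"))
    if "text" in data:
        return {"text": data["text"], "is_finished": False, "finish_reason": ""}
    if data.get("is_finished"):
        return {"text": "", "is_finished": True, "finish_reason": data["finish_reason"]}
    return None  # e.g. a stream-start event: nothing to emit

events = [
    b'{"event_type": "stream-start", "is_finished": false}',
    b'{"event_type": "text-generation", "text": "Hello", "is_finished": false}',
    b'{"event_type": "stream-end", "is_finished": true, "finish_reason": "COMPLETE"}',
]
for event in events:
    print(parse_cohere_chat_chunk(event))
```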


@@ -981,35 +981,45 @@
"litellm_provider": "gemini",
"mode": "chat"
},
"command-nightly": {
"cohere_chat/command-r": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000050,
"output_cost_per_token": 0.0000015,
"litellm_provider": "cohere_chat",
"mode": "chat"
},
"cohere_chat/command-light": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,
"litellm_provider": "cohere_chat",
"mode": "chat"
},
"cohere/command-nightly": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,
"litellm_provider": "cohere",
"mode": "completion"
},
"command": {
"cohere/command": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,
"litellm_provider": "cohere",
"mode": "completion"
},
"command-light": {
"cohere/command-medium-beta": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,
"litellm_provider": "cohere",
"mode": "completion"
},
"command-medium-beta": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,
"litellm_provider": "cohere",
"mode": "completion"
},
"command-xlarge-beta": {
"cohere/command-xlarge-beta": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015,