diff --git a/docs/my-website/docs/providers/cohere.md b/docs/my-website/docs/providers/cohere.md
index 980143770..c6efb3b40 100644
--- a/docs/my-website/docs/providers/cohere.md
+++ b/docs/my-website/docs/providers/cohere.md
@@ -17,7 +17,7 @@ os.environ["COHERE_API_KEY"] = "cohere key"
 
 # cohere call
 response = completion(
-    model="command-nightly",
+    model="command-r",
     messages = [{ "content": "Hello, how are you?","role": "user"}]
 )
 ```
@@ -32,7 +32,7 @@ os.environ["COHERE_API_KEY"] = "cohere key"
 
 # cohere call
 response = completion(
-    model="command-nightly",
+    model="command-r",
     messages = [{ "content": "Hello, how are you?","role": "user"}],
     stream=True
 )
@@ -41,7 +41,17 @@ for chunk in response:
     print(chunk)
 ```
 
-LiteLLM supports 'command', 'command-light', 'command-medium', 'command-medium-beta', 'command-xlarge-beta', 'command-nightly' models from [Cohere](https://cohere.com/).
+
+## Supported Models
+| Model Name | Function Call |
+|------------|----------------|
+| command-r | `completion('command-r', messages)` |
+| command-light | `completion('command-light', messages)` |
+| command-medium | `completion('command-medium', messages)` |
+| command-medium-beta | `completion('command-medium-beta', messages)` |
+| command-xlarge-beta | `completion('command-xlarge-beta', messages)` |
+| command-nightly | `completion('command-nightly', messages)` |
+
 
 ## Embedding
 
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 62a2d3842..44c4a30f4 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -131,6 +131,7 @@ const sidebars = {
         "providers/anthropic",
         "providers/aws_sagemaker",
         "providers/bedrock",
+        "providers/cohere",
         "providers/anyscale",
         "providers/huggingface",
         "providers/ollama",
@@ -143,7 +144,6 @@ const sidebars = {
         "providers/ai21",
         "providers/nlp_cloud",
         "providers/replicate",
-        "providers/cohere",
         "providers/togetherai",
         "providers/voyage",
         "providers/aleph_alpha",
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 2e3110aa4..a821bde30 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -252,6 +252,7 @@ config_path = None
 open_ai_chat_completion_models: List = []
 open_ai_text_completion_models: List = []
 cohere_models: List = []
+cohere_chat_models: List = []
 anthropic_models: List = []
 openrouter_models: List = []
 vertex_language_models: List = []
@@ -274,6 +275,8 @@ for key, value in model_cost.items():
         open_ai_text_completion_models.append(key)
     elif value.get("litellm_provider") == "cohere":
         cohere_models.append(key)
+    elif value.get("litellm_provider") == "cohere_chat":
+        cohere_chat_models.append(key)
     elif value.get("litellm_provider") == "anthropic":
         anthropic_models.append(key)
     elif value.get("litellm_provider") == "openrouter":
@@ -421,6 +424,7 @@ model_list = (
     open_ai_chat_completion_models
     + open_ai_text_completion_models
     + cohere_models
+    + cohere_chat_models
     + anthropic_models
     + replicate_models
     + openrouter_models
@@ -444,6 +448,7 @@ provider_list: List = [
     "custom_openai",
     "text-completion-openai",
     "cohere",
+    "cohere_chat",
     "anthropic",
     "replicate",
     "huggingface",
@@ -479,6 +484,7 @@ provider_list: List = [
 models_by_provider: dict = {
     "openai": open_ai_chat_completion_models + open_ai_text_completion_models,
     "cohere": cohere_models,
+    "cohere_chat": cohere_chat_models,
     "anthropic": anthropic_models,
     "replicate": replicate_models,
     "huggingface": huggingface_models,
diff --git a/litellm/llms/cohere_chat.py b/litellm/llms/cohere_chat.py
new file mode 100644
index 000000000..9027572e6
--- /dev/null
+++ b/litellm/llms/cohere_chat.py
@@ -0,0 +1,204 @@
+import os, types
+import json
+from enum import Enum
+import requests
+import time, traceback
+from typing import Callable, Optional
+from litellm.utils import ModelResponse, Choices, Message, Usage
+import litellm
+import httpx
+
+
+class CohereError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class CohereChatConfig:
+    """
+    Configuration class for Cohere's API interface.
+
+    Args:
+        preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
+        chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
+        generation_id (str, optional): Unique identifier for the generated reply.
+        response_id (str, optional): Unique identifier for the response.
+        conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
+        prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
+        connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
+        search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
+        documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
+        temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
+        max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
+        k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
+        p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
+        frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
+        presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
+        tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
+        tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
+ """ + + preamble: Optional[str] = None + chat_history: Optional[list] = None + generation_id: Optional[str] = None + response_id: Optional[str] = None + conversation_id: Optional[str] = None + prompt_truncation: Optional[str] = None + connectors: Optional[list] = None + search_queries_only: Optional[bool] = None + documents: Optional[list] = None + temperature: Optional[int] = None + max_tokens: Optional[int] = None + k: Optional[int] = None + p: Optional[int] = None + frequency_penalty: Optional[int] = None + presence_penalty: Optional[int] = None + tools: Optional[list] = None + tool_results: Optional[list] = None + + def __init__( + self, + preamble: Optional[str] = None, + chat_history: Optional[list] = None, + generation_id: Optional[str] = None, + response_id: Optional[str] = None, + conversation_id: Optional[str] = None, + prompt_truncation: Optional[str] = None, + connectors: Optional[list] = None, + search_queries_only: Optional[bool] = None, + documents: Optional[list] = None, + temperature: Optional[int] = None, + max_tokens: Optional[int] = None, + k: Optional[int] = None, + p: Optional[int] = None, + frequency_penalty: Optional[int] = None, + presence_penalty: Optional[int] = None, + tools: Optional[list] = None, + tool_results: Optional[list] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + +def validate_environment(api_key): + headers = { + "accept": "application/json", + "content-type": "application/json", + } + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + +def completion( + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params=None, + litellm_params=None, + logger_fn=None, +): + headers = validate_environment(api_key) + completion_url = api_base + model = model + prompt = " ".join(message["content"] for message in messages) + + ## Load Config + config = litellm.CohereConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + data = { + "model": model, + "message": prompt, + **optional_params, + } + + ## LOGGING + logging_obj.pre_call( + input=prompt, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "headers": headers, + "api_base": completion_url, + }, + ) + ## COMPLETION CALL + response = requests.post( + completion_url, + headers=headers, + data=json.dumps(data), + stream=optional_params["stream"] if "stream" in optional_params else False, + ) + ## error handling for cohere calls + if response.status_code != 200: + raise CohereError(message=response.text, status_code=response.status_code) + + if "stream" in optional_params and optional_params["stream"] == True: + return response.iter_lines() + else: + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + ## RESPONSE OBJECT + completion_response = response.json() + try: 
+            model_response.choices[0].message.content = completion_response["text"]  # type: ignore
+        except Exception as e:
+            raise CohereError(message=response.text, status_code=response.status_code)
+
+        ## CALCULATING USAGE - use cohere `billed_units` for returning usage
+        billed_units = completion_response.get("meta", {}).get("billed_units", {})
+
+        prompt_tokens = billed_units.get("input_tokens", 0)
+        completion_tokens = billed_units.get("output_tokens", 0)
+
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        model_response.usage = usage
+        return model_response
diff --git a/litellm/main.py b/litellm/main.py
index 06ba78c2b..509830631 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -55,6 +55,7 @@ from .llms import (
     ollama_chat,
     cloudflare,
     cohere,
+    cohere_chat,
     petals,
     oobabooga,
     openrouter,
@@ -1287,6 +1288,46 @@ def completion(
             )
             return response
         response = model_response
+    elif custom_llm_provider == "cohere_chat":
+        cohere_key = (
+            api_key
+            or litellm.cohere_key
+            or get_secret("COHERE_API_KEY")
+            or get_secret("CO_API_KEY")
+            or litellm.api_key
+        )
+
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("COHERE_API_BASE")
+            or "https://api.cohere.ai/v1/chat"
+        )
+
+        model_response = cohere_chat.completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            encoding=encoding,
+            api_key=cohere_key,
+            logging_obj=logging,  # model call logging done inside the class as we may need to modify I/O to fit Cohere's requirements
+        )
+
+        if "stream" in optional_params and optional_params["stream"] == True:
+            # don't try to access the stream object here; wrap it and return it to the caller
+            response = CustomStreamWrapper(
+                model_response,
+                model,
+                custom_llm_provider="cohere_chat",
+                logging_obj=logging,
+            )
+            return response
+        response = model_response
     elif custom_llm_provider == "maritalk":
         maritalk_key = (
             api_key
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 18c4b0d9a..55762982f 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -981,35 +981,45 @@
         "litellm_provider": "gemini",
         "mode": "chat"
     },
-    "command-nightly": {
+
+    "cohere_chat/command-r": {
+        "max_tokens": 128000,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000050,
+        "output_cost_per_token": 0.0000015,
+        "litellm_provider": "cohere_chat",
+        "mode": "chat"
+    },
+    "cohere_chat/command-light": {
+        "max_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015,
+        "litellm_provider": "cohere_chat",
+        "mode": "chat"
+    },
+    "cohere/command-nightly": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000015,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "cohere",
         "mode": "completion"
     },
-    "command": {
+    "cohere/command": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000015,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "cohere",
         "mode": "completion"
     },
-    "command-light": {
+    "cohere/command-medium-beta": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000015,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "cohere",
         "mode": "completion"
     },
-    "command-medium-beta": {
-        "max_tokens": 4096,
"input_cost_per_token": 0.000015, - "output_cost_per_token": 0.000015, - "litellm_provider": "cohere", - "mode": "completion" - }, - "command-xlarge-beta": { + "cohere/command-xlarge-beta": { "max_tokens": 4096, "input_cost_per_token": 0.000015, "output_cost_per_token": 0.000015, diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 2b372f57a..b86d16c56 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1984,6 +1984,50 @@ def test_completion_cohere(): pytest.fail(f"Error occurred: {e}") +# FYI - cohere_chat looks quite unstable, even when testing locally +def test_chat_completion_cohere(): + try: + litellm.set_verbose = True + messages = [ + {"role": "system", "content": "You're a good bot"}, + { + "role": "user", + "content": "Hey", + }, + ] + response = completion( + model="cohere_chat/command-r", + messages=messages, + max_tokens=10, + ) + print(response) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +def test_chat_completion_cohere_stream(): + try: + litellm.set_verbose = False + messages = [ + {"role": "system", "content": "You're a good bot"}, + { + "role": "user", + "content": "Hey", + }, + ] + response = completion( + model="cohere_chat/command-r", + messages=messages, + max_tokens=10, + stream=True, + ) + print(response) + for chunk in response: + print(chunk) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + def test_azure_cloudflare_api(): litellm.set_verbose = True try: diff --git a/litellm/utils.py b/litellm/utils.py index 262935faa..3c1bf989a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -7411,7 +7411,9 @@ def exception_type( model=model, response=original_exception.response, ) - elif custom_llm_provider == "cohere": # Cohere + elif ( + custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat" + ): # Cohere if ( "invalid api token" in error_str or "No API key provided." in error_str @@ -8544,6 +8546,29 @@ class CustomStreamWrapper: except: raise ValueError(f"Unable to parse response. Original response: {chunk}") + def handle_cohere_chat_chunk(self, chunk): + chunk = chunk.decode("utf-8") + data_json = json.loads(chunk) + print_verbose(f"chunk: {chunk}") + try: + text = "" + is_finished = False + finish_reason = "" + if "text" in data_json: + text = data_json["text"] + elif "is_finished" in data_json and data_json["is_finished"] == True: + is_finished = data_json["is_finished"] + finish_reason = data_json["finish_reason"] + else: + return + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + except: + raise ValueError(f"Unable to parse response. 
+
     def handle_azure_chunk(self, chunk):
         is_finished = False
         finish_reason = ""
@@ -9073,6 +9098,15 @@ class CustomStreamWrapper:
                     model_response.choices[0].finish_reason = response_obj[
                         "finish_reason"
                     ]
+            elif self.custom_llm_provider == "cohere_chat":
+                response_obj = self.handle_cohere_chat_chunk(chunk)
+                if response_obj is None:
+                    return
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj[
+                        "finish_reason"
+                    ]
             elif self.custom_llm_provider == "bedrock":
                 if self.sent_last_chunk:
                     raise StopIteration
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 18c4b0d9a..55762982f 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -981,35 +981,45 @@
         "litellm_provider": "gemini",
         "mode": "chat"
     },
-    "command-nightly": {
+
+    "cohere_chat/command-r": {
+        "max_tokens": 128000,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000050,
+        "output_cost_per_token": 0.0000015,
+        "litellm_provider": "cohere_chat",
+        "mode": "chat"
+    },
+    "cohere_chat/command-light": {
+        "max_tokens": 4096,
+        "input_cost_per_token": 0.000015,
+        "output_cost_per_token": 0.000015,
+        "litellm_provider": "cohere_chat",
+        "mode": "chat"
+    },
+    "cohere/command-nightly": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000015,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "cohere",
         "mode": "completion"
     },
-    "command": {
+    "cohere/command": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000015,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "cohere",
         "mode": "completion"
     },
-    "command-light": {
+    "cohere/command-medium-beta": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000015,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "cohere",
         "mode": "completion"
     },
-    "command-medium-beta": {
-        "max_tokens": 4096,
-        "input_cost_per_token": 0.000015,
-        "output_cost_per_token": 0.000015,
-        "litellm_provider": "cohere",
-        "mode": "completion"
-    },
-    "command-xlarge-beta": {
+    "cohere/command-xlarge-beta": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000015,
         "output_cost_per_token": 0.000015,
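For reviewers, a minimal usage sketch of the route this patch adds (not part of the diff itself). It assumes a valid `COHERE_API_KEY` and mirrors the calls exercised in `test_completion.py`; the `cohere_chat/<model>` naming follows the entries registered in `model_prices_and_context_window.json` above.

```python
import os
from litellm import completion

os.environ["COHERE_API_KEY"] = "cohere key"  # assumption: real key set in the environment

# non-streaming chat call, routed through litellm/llms/cohere_chat.py
response = completion(
    model="cohere_chat/command-r",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)

# streaming call; chunks are parsed by CustomStreamWrapper.handle_cohere_chat_chunk
response = completion(
    model="cohere_chat/command-r",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    stream=True,
)
for chunk in response:
    print(chunk)
```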