commit af2d5a11b2
Author: sajda (committed by GitHub)
Date:   2025-04-24 00:55:59 -07:00
11 changed files with 2513 additions and 6 deletions

View file

@@ -383,6 +383,7 @@ open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []
cohere_models: List = []
cohere_chat_models: List = []
cohere_v2_models: List = []
mistral_chat_models: List = []
text_completion_codestral_models: List = []
anthropic_models: List = []
@@ -480,6 +481,8 @@ def add_known_models():
cohere_models.append(key)
elif value.get("litellm_provider") == "cohere_chat":
cohere_chat_models.append(key)
elif value.get("litellm_provider") == "cohere_v2":
cohere_v2_models.append(key)
elif value.get("litellm_provider") == "mistral":
mistral_chat_models.append(key)
elif value.get("litellm_provider") == "anthropic":
@@ -623,6 +626,7 @@ model_list = (
+ open_ai_text_completion_models
+ cohere_models
+ cohere_chat_models
+ cohere_v2_models
+ anthropic_models
+ replicate_models
+ openrouter_models
@@ -674,8 +678,9 @@ provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
models_by_provider: dict = {
"openai": open_ai_chat_completion_models + open_ai_text_completion_models,
"text-completion-openai": open_ai_text_completion_models,
- "cohere": cohere_models + cohere_chat_models,
+ "cohere": cohere_models + cohere_chat_models + cohere_v2_models,
"cohere_chat": cohere_chat_models,
"cohere_v2": cohere_v2_models,
"anthropic": anthropic_models,
"replicate": replicate_models,
"huggingface": huggingface_models,
@@ -940,6 +945,7 @@ from .llms.bedrock.embed.amazon_titan_v2_transformation import (
AmazonTitanV2Config,
)
from .llms.cohere.chat.transformation import CohereChatConfig
from .llms.cohere.chat.transformation_v2 import CohereChatConfigV2
from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
from .llms.openai.image_variations.transformation import OpenAIImageVariationConfig
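The registry changes above are mechanical: any cost-map entry whose litellm_provider is "cohere_v2" is collected into cohere_v2_models, appended to the global model_list, and exposed under both the "cohere" and "cohere_v2" keys of models_by_provider. A minimal standalone sketch of that bucketing follows; the model names and cost-map entries are hypothetical, not taken from litellm's real cost map, and the non-chat cohere_models list is omitted for brevity.

# Hypothetical cost-map entries; litellm's real loader walks litellm.model_cost.
model_cost = {
    "command-r": {"litellm_provider": "cohere_chat"},
    "command-r-v2": {"litellm_provider": "cohere_v2"},  # illustrative only
}

cohere_chat_models, cohere_v2_models = [], []
for name, meta in model_cost.items():
    if meta.get("litellm_provider") == "cohere_chat":
        cohere_chat_models.append(name)
    elif meta.get("litellm_provider") == "cohere_v2":
        cohere_v2_models.append(name)

# Mirrors the models_by_provider change: v2 models appear under both keys.
models_by_provider = {
    "cohere": cohere_chat_models + cohere_v2_models,
    "cohere_v2": cohere_v2_models,
}
print(models_by_provider)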

View file

@@ -105,6 +105,7 @@ LITELLM_CHAT_PROVIDERS = [
"text-completion-openai",
"cohere",
"cohere_chat",
"cohere_v2",
"clarifai",
"anthropic",
"anthropic_text",

View file

@@ -23,14 +23,16 @@ def _is_non_openai_azure_model(model: str) -> bool:
def handle_cohere_chat_model_custom_llm_provider(
- model: str, custom_llm_provider: Optional[str] = None
+ model: str, custom_llm_provider: Optional[str] = None, api_version: Optional[str] = None
) -> Tuple[str, Optional[str]]:
"""
if user sets model = "cohere/command-r" -> use custom_llm_provider = "cohere_chat"
if api_version = "v2" -> use custom_llm_provider = "cohere_v2"
Args:
- model:
- custom_llm_provider:
+ model: The model name
+ custom_llm_provider: The custom LLM provider if specified
+ api_version: The API version (v1 or v2)
Returns:
model, custom_llm_provider
@@ -38,6 +40,9 @@ def handle_cohere_chat_model_custom_llm_provider(
if custom_llm_provider:
if custom_llm_provider == "cohere" and model in litellm.cohere_chat_models:
# Check if v2 API version is specified
if api_version == "v2":
return model, "cohere_v2"
return model, "cohere_chat"
if "/" in model:
@@ -47,6 +52,9 @@ def handle_cohere_chat_model_custom_llm_provider(
and _custom_llm_provider == "cohere"
and _model in litellm.cohere_chat_models
):
# Check if v2 API version is specified
if api_version == "v2":
return _model, "cohere_v2"
return _model, "cohere_chat"
return model, custom_llm_provider
@@ -122,8 +130,23 @@ def get_llm_provider( # noqa: PLR0915
return model, custom_llm_provider, dynamic_api_key, api_base
### Handle cases when custom_llm_provider is set to cohere/command-r-plus but it should use cohere_chat route
# Extract api_version from optional_params if it exists
api_version = None
if litellm_params and hasattr(litellm_params, "optional_params") and litellm_params.optional_params:
api_version = litellm_params.optional_params.get("api_version")
# Handle the explicit cohere_v2/ model prefix
if model.startswith("cohere_v2/"):
model = model.replace("cohere_v2/", "")
custom_llm_provider = "cohere_v2"
model, custom_llm_provider = handle_cohere_chat_model_custom_llm_provider(
- model, custom_llm_provider
+ model, custom_llm_provider, api_version
)
model, custom_llm_provider = handle_anthropic_text_model_custom_llm_provider(

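Routing precedence after this change, sketched standalone below: an explicit cohere_v2/ prefix wins outright, otherwise api_version="v2" (read from litellm_params.optional_params) upgrades a known Cohere chat model from the cohere_chat route to cohere_v2. COHERE_CHAT_MODELS is a stand-in for litellm.cohere_chat_models, and the helper is a simplification of the two functions above, not litellm's actual API.

COHERE_CHAT_MODELS = {"command-r", "command-r-plus"}  # illustrative subset

def resolve_cohere_route(model, api_version=None):
    # 1) An explicit cohere_v2/ prefix wins outright.
    if model.startswith("cohere_v2/"):
        return model.removeprefix("cohere_v2/"), "cohere_v2"
    # 2) Known Cohere chat models (bare or cohere/-prefixed) are upgraded to v2
    #    only when api_version == "v2"; otherwise they stay on cohere_chat.
    bare = model.removeprefix("cohere/")
    if bare in COHERE_CHAT_MODELS:
        return bare, "cohere_v2" if api_version == "v2" else "cohere_chat"
    return model, "cohere"

assert resolve_cohere_route("cohere_v2/command-r") == ("command-r", "cohere_v2")
assert resolve_cohere_route("cohere/command-r", api_version="v2") == ("command-r", "cohere_v2")
assert resolve_cohere_route("command-r") == ("command-r", "cohere_chat")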
View file

@@ -2005,6 +2005,57 @@ def cohere_messages_pt_v2( # noqa: PLR0915
return returned_message, new_messages
def cohere_messages_pt_v3(messages: List, model: str, llm_provider: str):
"""
Format messages for the Cohere v2 chat API.
Messages are combined into a single array; assistant turns are emitted with the
CHATBOT role and system messages are folded into USER turns, e.g.:
[
{"role": "USER", "content": "Hello"},
{"role": "CHATBOT", "content": "Hi there!"},
{"role": "USER", "content": "How are you?"}
]
Returns:
List of formatted messages in Cohere v2 format
"""
cohere_messages = []
for msg_i, message in enumerate(messages):
role = message["role"].upper()
# Map OpenAI roles to Cohere v2 roles
if role == "USER":
pass # Keep as USER
elif role == "ASSISTANT":
role = "CHATBOT" # Cohere v2 uses CHATBOT instead of ASSISTANT
elif role == "SYSTEM":
role = "USER" # System messages are sent as USER with a special prefix
message["content"] = f"<admin>{message['content']}</admin>"
elif role == "TOOL":
# Skip tool messages as they'll be handled separately with tool_results
continue
elif role == "FUNCTION":
# Skip function messages as they'll be handled separately with tool_results
continue
# Handle content
content = ""
if isinstance(message.get("content"), str):
content = message["content"]
elif isinstance(message.get("content"), list):
# Handle content list (text and images)
for item in message["content"]:
if isinstance(item, dict):
if item.get("type") == "text":
content += item.get("text", "")
# Add message to the list
cohere_messages.append({"role": role, "content": content})
return cohere_messages
def cohere_message_pt(messages: list):
tool_calls: List = get_all_tool_calls(messages=messages)
prompt = ""

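For reference, here is the role and content mapping performed by cohere_messages_pt_v3, reproduced as a standalone function so it can be run without litellm installed; tool/function messages and non-text content parts are omitted for brevity.

def to_cohere_v2_messages(messages):
    out = []
    for m in messages:
        role = m["role"].upper()
        content = m.get("content") or ""
        if role == "ASSISTANT":
            role = "CHATBOT"  # assistant turns use the CHATBOT role
        elif role == "SYSTEM":
            role = "USER"  # system prompt is folded into a USER turn
            content = f"<admin>{content}</admin>"
        elif role in ("TOOL", "FUNCTION"):
            continue  # handled separately via tool_results
        out.append({"role": role, "content": content})
    return out

print(to_cohere_v2_messages([
    {"role": "system", "content": "You're a good bot"},
    {"role": "user", "content": "Hey"},
    {"role": "assistant", "content": "Hi there!"},
]))
# [{'role': 'USER', 'content': "<admin>You're a good bot</admin>"},
#  {'role': 'USER', 'content': 'Hey'},
#  {'role': 'CHATBOT', 'content': 'Hi there!'}]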
View file

@@ -0,0 +1,375 @@
"""Cohere Chat V2 API Integration Module.
This module provides the necessary classes and functions to interact with Cohere's V2 Chat API.
It handles the transformation of requests and responses between LiteLLM's standard format and
Cohere's specific API requirements.
"""
import json
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v3
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
# Use absolute imports instead of relative imports
from litellm.llms.cohere.common_utils import ModelResponseIterator as CohereModelResponseIterator
from litellm.llms.cohere.common_utils import validate_environment as cohere_validate_environment
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class CohereErrorV2(BaseLLMException):
"""
Exception class for Cohere v2 API errors.
This class handles errors returned by the Cohere v2 API and formats them
in a way that is consistent with the LiteLLM error handling system.
"""
def __init__(
self,
status_code: int,
message: str,
headers: Optional[httpx.Headers] = None,
):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cohere.com/v2/chat")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
status_code=status_code,
message=message,
headers=headers,
)
class CohereChatConfigV2(BaseConfig):
"""
Configuration class for Cohere's V2 API interface.
Args:
preamble (str, optional): When specified, the default Cohere preamble will be replaced
with the provided one.
generation_id (str, optional): Unique identifier for the generated reply.
conversation_id (str, optional): Creates or resumes a persisted conversation.
prompt_truncation (str, optional): Dictates how the prompt will be constructed.
Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search)
to enrich the model's reply.
search_queries_only (bool, optional): When true, the response will only contain a list
of generated search queries.
documents (List[Dict[str, str]] or List[str], optional): A list of relevant documents
that the model can cite.
temperature (float, optional): A non-negative float that tunes the degree of randomness
in generation.
max_tokens (int, optional): The maximum number of tokens the model will generate as part
of the response.
k (int, optional): Ensures only the top k most likely tokens are considered for generation
at each step.
p (float, optional): Ensures that only the most likely tokens, with total probability mass
of p, are considered for generation.
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model
may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
"""
preamble: Optional[str] = None
generation_id: Optional[str] = None
conversation_id: Optional[str] = None
prompt_truncation: Optional[str] = None
connectors: Optional[list] = None
search_queries_only: Optional[bool] = None
documents: Optional[list] = None
temperature: Optional[float] = None
max_tokens: Optional[int] = None
k: Optional[int] = None
p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__(self, **kwargs) -> None:
"""
Initialize the CohereChatConfigV2 with parameters matching Cohere v2 API specification.
All parameters are passed as keyword arguments and set as class attributes
if they have a non-None value. This approach allows for future API changes
without requiring code modifications.
Args:
**kwargs: Arbitrary keyword arguments matching Cohere v2 API parameters.
See class docstring for details on supported parameters.
"""
# Process all keyword arguments and set as class attributes if not None
for key, value in kwargs.items():
if value is not None:
setattr(self.__class__, key, value)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
# Use the api_key parameter directly
# litellm_params is required by the base class but not used by cohere_validate_environment
return cohere_validate_environment(
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
api_key=api_key,
api_version="v2" # Specify v2 API version
)
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"tools",
"tool_choice",
"seed",
"extra_headers",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "stream":
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "n":
optional_params["num_generations"] = value
if param == "top_p":
optional_params["p"] = value
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "tools":
cohere_tools = self._construct_cohere_tool(tools=value)
optional_params["tools"] = cohere_tools
if param == "seed":
optional_params["seed"] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
## Load config defaults; explicit call-site params take precedence
for k, v in litellm.CohereChatConfigV2.get_config().items():
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3)
# Allows for dynamic variables to be passed in
optional_params[k] = v
# In v2, messages are combined in a single array
cohere_messages = cohere_messages_pt_v3(
messages=messages, model=model, llm_provider="cohere_chat"
)
optional_params["messages"] = cohere_messages
optional_params["model"] = model
## Tool Calling is now handled in map_openai_params
# Handle tool results if present
if "tool_results" in optional_params and isinstance(optional_params["tool_results"], list):
# Convert tool results to v2 format if needed
tool_results = []
for result in optional_params["tool_results"]:
if isinstance(result, dict) and "content" in result:
# Format from v1 to v2
tool_result = {
"tool_call_id": result.get("tool_call_id", ""),
"output": result.get("content", ""),
}
tool_results.append(tool_result)
else:
# Already in v2 format
tool_results.append(result)
optional_params["tool_results"] = tool_results
return optional_params
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
raw_response_json = raw_response.json()
# Set the text content from the response (handle both regular and streaming choices)
if hasattr(model_response.choices[0], 'message'):
model_response.choices[0].message.content = raw_response_json.get("text", "")
else:
# For streaming responses
model_response.choices[0].delta.content = raw_response_json.get("text", "")
except Exception as exc:
raise CohereErrorV2(
message=raw_response.text, status_code=raw_response.status_code
) from exc
## ADD CITATIONS
# Add citation information to the model response if available
if "citations" in raw_response_json:
citations = raw_response_json["citations"]
setattr(model_response, "citations", citations)
## Tool calling response
cohere_tools_response = raw_response_json.get("tool_calls", None)
if cohere_tools_response is not None and cohere_tools_response != []:
# convert cohere_tools_response to OpenAI response format
tool_calls = []
for tool in cohere_tools_response:
function_name = tool.get("name", "")
tool_call_id = tool.get("id", "")
parameters = tool.get("parameters", {})
tool_call = {
"id": tool_call_id,
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(parameters),
},
}
tool_calls.append(tool_call)
_message = litellm.Message(
tool_calls=tool_calls,
content=None,
)
model_response.choices[0].message = _message # type: ignore
## CALCULATING USAGE - Cohere v2 returns token counts under `usage`
token_usage = raw_response_json.get("usage", {})
prompt_tokens = token_usage.get("input_tokens", 0)
completion_tokens = token_usage.get("output_tokens", 0)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def _construct_cohere_tool(
self,
tools: Optional[list] = None,
):
if tools is None:
tools = []
cohere_tools = []
for tool in tools:
cohere_tool = self._translate_openai_tool_to_cohere(tool)
cohere_tools.append(cohere_tool)
return cohere_tools
def _translate_openai_tool_to_cohere(
self,
openai_tool: dict,
):
"""
Translates OpenAI tool format to Cohere v2 tool format
Cohere v2 tools look like this:
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
"""
cohere_tool = {
"name": openai_tool["function"]["name"],
"description": openai_tool["function"]["description"],
"input_schema": openai_tool["function"]["parameters"],
}
return cohere_tool
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
) -> Any:
return CohereModelResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereErrorV2(status_code=status_code, message=error_message)
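To make the tool handling above concrete, the OpenAI-to-Cohere tool translation from _translate_openai_tool_to_cohere is reproduced below as a standalone function; the weather tool is the same illustrative example used in the docstring, not a real API call.

import json

def translate_openai_tool_to_cohere(openai_tool: dict) -> dict:
    # Same field mapping as CohereChatConfigV2._translate_openai_tool_to_cohere.
    return {
        "name": openai_tool["function"]["name"],
        "description": openai_tool["function"]["description"],
        "input_schema": openai_tool["function"]["parameters"],
    }

openai_tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    },
}
print(json.dumps(translate_openai_tool_to_cohere(openai_tool), indent=2))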

View file

@@ -21,11 +21,15 @@ def validate_environment(
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_version: Optional[str] = "v1",
) -> dict:
"""
Return headers to use for cohere chat completion request
- Cohere API Ref: https://docs.cohere.com/reference/chat
+ Cohere API Ref:
+ - v1: https://docs.cohere.com/reference/chat
+ - v2: https://docs.cohere.com/v2/reference/chat
Expected headers:
{
"Request-Source": "unspecified:litellm",

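A hedged sketch of what the shared header builder is expected to produce once api_version is threaded through. Only the Request-Source header is visible in the excerpt above, so the bearer-auth header and the exact header set here are assumptions rather than confirmed litellm behavior.

from typing import Optional

def build_cohere_headers(api_key: str, extra: Optional[dict] = None) -> dict:
    headers = {
        "Request-Source": "unspecified:litellm",
        "Authorization": f"Bearer {api_key}",  # assumed bearer auth, as used by the Cohere API
    }
    headers.update(extra or {})
    return headers

print(build_cohere_headers("co-xxxx", extra={"X-Debug": "1"}))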
View file

@@ -2143,6 +2143,46 @@ def completion( # type: ignore # noqa: PLR0915
api_key=cohere_key,
logging_obj=logging, # model call logging is done inside the class, as we may need to modify I/O to fit Cohere's requirements
)
elif custom_llm_provider == "cohere_v2":
cohere_key = (
api_key
or litellm.cohere_key
or get_secret_str("COHERE_API_KEY")
or get_secret_str("CO_API_KEY")
or litellm.api_key
)
api_base = (
api_base
or litellm.api_base
or get_secret_str("COHERE_API_BASE")
or "https://api.cohere.ai/v2/chat"
)
headers = headers or litellm.headers or {}
if extra_headers is not None:
headers.update(extra_headers)
response = base_llm_http_handler.completion(
model=model,
stream=stream,
messages=messages,
acompletion=acompletion,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
litellm_params=litellm_params,
custom_llm_provider="cohere_v2",
timeout=timeout,
headers=headers,
encoding=encoding,
api_key=cohere_key,
logging_obj=logging, # model call logging is done inside the class, as we may need to modify I/O to fit Cohere's requirements
client=client,
)
elif custom_llm_provider == "maritalk":
maritalk_key = (
api_key

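A usage sketch of the new branch, assuming it ships as written: the key is resolved from api_key / litellm.cohere_key / COHERE_API_KEY / CO_API_KEY / litellm.api_key in that order, and the base URL defaults to https://api.cohere.ai/v2/chat. Running this requires a valid Cohere key and network access; the key value below is a placeholder.

import os
import litellm

os.environ.setdefault("COHERE_API_KEY", "<your-cohere-api-key>")  # placeholder, not a real key

response = litellm.completion(
    model="cohere_v2/command-r",  # the cohere_v2/ prefix selects this branch
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=10,
)
print(response.choices[0].message.content)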
View file

@@ -2040,6 +2040,7 @@ class LlmProviders(str, Enum):
TEXT_COMPLETION_OPENAI = "text-completion-openai"
COHERE = "cohere"
COHERE_CHAT = "cohere_chat"
COHERE_V2 = "cohere_v2"
CLARIFAI = "clarifai"
ANTHROPIC = "anthropic"
ANTHROPIC_TEXT = "anthropic_text"

View file

@@ -6409,6 +6409,8 @@ class ProviderConfigManager:
return litellm.OpenAITextCompletionConfig()
elif litellm.LlmProviders.COHERE_CHAT == provider:
return litellm.CohereChatConfig()
elif litellm.LlmProviders.COHERE_V2 == provider:
return litellm.CohereChatConfigV2()
elif litellm.LlmProviders.COHERE == provider:
return litellm.CohereConfig()
elif litellm.LlmProviders.SNOWFLAKE == provider:

File diff suppressed because it is too large.

View file

@@ -0,0 +1,999 @@
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
# For testing, make sure the COHERE_API_KEY or CO_API_KEY environment variable is set
# You can set it before running the tests with: export COHERE_API_KEY=your_api_key
import io
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import json
import pytest
import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from unittest.mock import AsyncMock, patch
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
litellm.num_retries = 3
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.flaky(retries=3, delay=1)
@pytest.mark.asyncio
async def test_chat_completion_cohere_v2_citations(stream):
try:
class MockResponse:
def __init__(self, status_code, json_data, is_stream=False):
self.status_code = status_code
self._json_data = json_data
self.headers = {}
self.is_stream = is_stream
# For streaming responses with citations
if is_stream:
# Create streaming chunks with citations at the end
self._iter_content_chunks = [
json.dumps({"text": "Emperor"}).encode(),
json.dumps({"text": " penguins"}).encode(),
json.dumps({"text": " are"}).encode(),
json.dumps({"text": " the"}).encode(),
json.dumps({"text": " tallest"}).encode(),
json.dumps({"text": " and"}).encode(),
json.dumps({"text": " they"}).encode(),
json.dumps({"text": " live"}).encode(),
json.dumps({"text": " in"}).encode(),
json.dumps({"text": " Antarctica"}).encode(),
json.dumps({"text": "."}).encode(),
# Citations in a separate chunk
json.dumps({"citations": [
{
"start": 0,
"end": 30,
"text": "Emperor penguins are the tallest",
"document_ids": ["doc1"]
},
{
"start": 31,
"end": 70,
"text": "they live in Antarctica",
"document_ids": ["doc2"]
}
]}).encode(),
json.dumps({"finish_reason": "COMPLETE"}).encode(),
]
def json(self):
return self._json_data
@property
def text(self):
return json.dumps(self._json_data)
def iter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
async def aiter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
async def mock_async_post(*args, **kwargs):
# For asynchronous HTTP client
data = kwargs.get("data", "{}")
request_body = json.loads(data)
print("Async Request body:", request_body)
# Verify the messages are formatted correctly for v2
messages = request_body.get("messages", [])
assert len(messages) > 0
assert "role" in messages[0]
assert "content" in messages[0]
# Check if documents are included
documents = request_body.get("documents", [])
assert len(documents) > 0
# Mock response with citations
mock_response = {
"text": "Emperor penguins are the tallest penguins and they live in Antarctica.",
"generation_id": "mock-id",
"id": "mock-completion",
"usage": {"input_tokens": 10, "output_tokens": 20},
"citations": [
{
"start": 0,
"end": 30,
"text": "Emperor penguins are the tallest",
"document_ids": ["doc1"]
},
{
"start": 31,
"end": 70,
"text": "they live in Antarctica",
"document_ids": ["doc2"]
}
]
}
# Create a streaming response with citations
if stream:
return MockResponse(
200,
{
"text": "Emperor penguins are the tallest penguins and they live in Antarctica.",
"generation_id": "mock-id",
"id": "mock-completion",
"usage": {"input_tokens": 10, "output_tokens": 20},
"citations": [
{
"start": 0,
"end": 30,
"text": "Emperor penguins are the tallest",
"document_ids": ["doc1"]
},
{
"start": 31,
"end": 70,
"text": "they live in Antarctica",
"document_ids": ["doc2"]
}
],
"stream": True
},
is_stream=True
)
else:
return MockResponse(200, mock_response)
# Mock the async HTTP client
with patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=AsyncMock, side_effect=mock_async_post):
litellm.set_verbose = True
messages = [
{
"role": "user",
"content": "Which penguins are the tallest?",
},
]
response = await litellm.acompletion(
model="cohere_v2/command-r",
messages=messages,
stream=stream,
documents=[
{"title": "Tall penguins", "text": "Emperor penguins are the tallest."},
{
"title": "Penguin habitats",
"text": "Emperor penguins only live in Antarctica.",
},
],
)
if stream:
citations_chunk = False
async for chunk in response:
print("received chunk", chunk)
if hasattr(chunk, "citations") or (isinstance(chunk, dict) and "citations" in chunk):
citations_chunk = True
break
assert citations_chunk
else:
assert hasattr(response, "citations")
except litellm.ServiceUnavailableError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_cohere_v2_command_r_plus_function_call():
litellm.set_verbose = True
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
try:
# test without max tokens
response = completion(
model="command-r-plus",
messages=messages,
tools=tools,
tool_choice="auto",
api_version="v2", # Specify v2 API version
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
messages.append(
response.choices[0].message.model_dump()
) # Add assistant tool invokes
tool_result = (
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
)
# Add user submitted tool results in the OpenAI format
messages.append(
{
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"role": "tool",
"name": response.choices[0].message.tool_calls[0].function.name,
"content": tool_result,
}
)
# In the second response, Cohere should deduce answer from tool results
second_response = completion(
model="command-r-plus",
messages=messages,
tools=tools,
tool_choice="auto",
force_single_step=True,
api_version="v2", # Specify v2 API version
)
print(second_response)
except litellm.Timeout:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.flaky(retries=6, delay=1)
def test_completion_cohere_v2():
try:
# litellm.set_verbose=True
messages = [
{"role": "system", "content": "You're a good bot"},
{
"role": "user",
"content": "Hey",
},
]
response = completion(
model="command-r",
messages=messages,
api_version="v2", # Specify v2 API version
)
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_chat_completion_cohere_v2(sync_mode):
try:
class MockResponse:
def __init__(self, status_code, json_data, is_stream=False):
self.status_code = status_code
self._json_data = json_data
self.headers = {}
self.is_stream = is_stream
# For streaming responses with citations
if is_stream:
# Create streaming chunks with citations at the end
self._iter_content_chunks = [
json.dumps({"text": "Emperor"}).encode(),
json.dumps({"text": " penguins"}).encode(),
json.dumps({"text": " are"}).encode(),
json.dumps({"text": " the"}).encode(),
json.dumps({"text": " tallest"}).encode(),
json.dumps({"text": " and"}).encode(),
json.dumps({"text": " they"}).encode(),
json.dumps({"text": " live"}).encode(),
json.dumps({"text": " in"}).encode(),
json.dumps({"text": " Antarctica"}).encode(),
json.dumps({"text": "."}).encode(),
# Citations in a separate chunk
json.dumps({"citations": [
{
"start": 0,
"end": 30,
"text": "Emperor penguins are the tallest",
"document_ids": ["doc1"]
},
{
"start": 31,
"end": 70,
"text": "they live in Antarctica",
"document_ids": ["doc2"]
}
]}).encode(),
json.dumps({"finish_reason": "COMPLETE"}).encode(),
]
def json(self):
return self._json_data
@property
def text(self):
return json.dumps(self._json_data)
def iter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
async def aiter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
def mock_sync_post(*args, **kwargs):
# For synchronous HTTP client
data = kwargs.get("data", "{}")
request_body = json.loads(data)
print("Sync Request body:", request_body)
# Verify the model is passed correctly
assert request_body.get("model") == "command-r"
# Verify max_tokens is passed correctly
assert request_body.get("max_tokens") == 10
# Verify the messages are formatted correctly for v2
messages = request_body.get("messages", [])
assert len(messages) > 0
assert "role" in messages[0]
assert "content" in messages[0]
# Mock response
return MockResponse(
200,
{
"text": "This is a mocked response for sync request",
"generation_id": "mock-id",
"id": "mock-completion",
"usage": {"input_tokens": 10, "output_tokens": 20},
},
)
async def mock_async_post(*args, **kwargs):
# For asynchronous HTTP client
data = kwargs.get("data", "{}")
request_body = json.loads(data)
print("Async Request body:", request_body)
# Verify the model is passed correctly
assert request_body.get("model") == "command-r"
# Verify max_tokens is passed correctly
assert request_body.get("max_tokens") == 10
# Verify the messages are formatted correctly for v2
messages = request_body.get("messages", [])
assert len(messages) > 0
assert "role" in messages[0]
assert "content" in messages[0]
# Mock response
return MockResponse(
200,
{
"text": "This is a mocked response for async request",
"generation_id": "mock-id",
"id": "mock-completion",
"usage": {"input_tokens": 10, "output_tokens": 20},
},
)
# Mock both sync and async HTTP clients
with patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post", side_effect=mock_sync_post):
with patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=AsyncMock, side_effect=mock_async_post):
litellm.set_verbose = True
messages = [
{"role": "system", "content": "You're a good bot"},
{
"role": "user",
"content": "Hey",
},
]
if sync_mode is False:
response = await litellm.acompletion(
model="cohere_v2/command-r",
messages=messages,
max_tokens=10,
)
else:
response = completion(
model="cohere_v2/command-r",
messages=messages,
max_tokens=10,
)
print(response)
assert response is not None
assert "This is a mocked response" in response.choices[0].message.content
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [False])
async def test_chat_completion_cohere_v2_stream(sync_mode):
try:
class MockResponse:
def __init__(self, status_code, json_data, is_stream=False):
self.status_code = status_code
self._json_data = json_data
self.headers = {}
self.is_stream = is_stream
# For streaming responses
if is_stream:
self._iter_content_chunks = [
json.dumps({"text": "This"}).encode(),
json.dumps({"text": " is"}).encode(),
json.dumps({"text": " a"}).encode(),
json.dumps({"text": " streamed"}).encode(),
json.dumps({"text": " response"}).encode(),
json.dumps({"text": "."}).encode(),
json.dumps({"finish_reason": "COMPLETE"}).encode(),
]
def json(self):
return self._json_data
@property
def text(self):
return json.dumps(self._json_data)
def iter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
async def aiter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
async def mock_async_post(*args, **kwargs):
# For asynchronous HTTP client
data = kwargs.get("data", "{}")
request_body = json.loads(data)
print("Async Request body:", request_body)
# Verify the model is passed correctly
assert request_body.get("model") == "command-r"
# Verify max_tokens is passed correctly
assert request_body.get("max_tokens") == 10
# Verify stream is set to True
assert request_body.get("stream") == True
# Verify the messages are formatted correctly for v2
messages = request_body.get("messages", [])
assert len(messages) > 0
assert "role" in messages[0]
assert "content" in messages[0]
# Return a streaming response
return MockResponse(
200,
{
"text": "This is a streamed response.",
"generation_id": "mock-id",
"id": "mock-completion",
"usage": {"input_tokens": 10, "output_tokens": 20},
},
is_stream=True
)
# Mock the async HTTP client for streaming
with patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=AsyncMock, side_effect=mock_async_post):
litellm.set_verbose = True
messages = [
{"role": "system", "content": "You're a good bot"},
{
"role": "user",
"content": "Hey",
},
]
if sync_mode is False:
response = await litellm.acompletion(
model="cohere_v2/command-r",
messages=messages,
stream=True,
max_tokens=10,
)
# Verify we get streaming chunks
chunk_count = 0
async for chunk in response:
print(f"chunk: {chunk}")
chunk_count += 1
assert chunk_count > 0, "No streaming chunks were received"
else:
# This test is only for async mode
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_cohere_v2_mock_completion():
"""
Test cohere_chat_v2 completion with mocked responses to avoid API calls
"""
try:
import httpx
class MockResponse:
def __init__(self, status_code, json_data, is_stream=False):
self.status_code = status_code
self._json_data = json_data
self.headers = {}
self.is_stream = is_stream
# For streaming responses with citations
if is_stream:
# Create streaming chunks with citations at the end
self._iter_content_chunks = [
json.dumps({"text": "Emperor"}).encode(),
json.dumps({"text": " penguins"}).encode(),
json.dumps({"text": " are"}).encode(),
json.dumps({"text": " the"}).encode(),
json.dumps({"text": " tallest"}).encode(),
json.dumps({"text": " and"}).encode(),
json.dumps({"text": " they"}).encode(),
json.dumps({"text": " live"}).encode(),
json.dumps({"text": " in"}).encode(),
json.dumps({"text": " Antarctica"}).encode(),
json.dumps({"text": "."}).encode(),
# Citations in a separate chunk
json.dumps({"citations": [
{
"start": 0,
"end": 30,
"text": "Emperor penguins are the tallest",
"document_ids": ["doc1"]
},
{
"start": 31,
"end": 70,
"text": "they live in Antarctica",
"document_ids": ["doc2"]
}
]}).encode(),
json.dumps({"finish_reason": "COMPLETE"}).encode(),
]
def json(self):
return self._json_data
@property
def text(self):
return json.dumps(self._json_data)
def iter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
async def aiter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
def mock_sync_post(*args, **kwargs):
# For synchronous HTTP client
data = kwargs.get("data", "{}")
request_body = json.loads(data)
print("Sync Request body:", request_body)
# Verify the messages are formatted correctly for v2
messages = request_body.get("messages", [])
assert len(messages) > 0
assert "role" in messages[0]
assert "content" in messages[0]
# Mock response
return MockResponse(
200,
{
"text": "This is a mocked response from Cohere v2 API",
"generation_id": "mock-id",
"id": "mock-completion",
"usage": {"input_tokens": 10, "output_tokens": 20},
},
)
async def mock_async_post(*args, **kwargs):
# For asynchronous HTTP client
data = kwargs.get("data", "{}")
request_body = json.loads(data)
print("Async Request body:", request_body)
# Verify the messages are formatted correctly for v2
messages = request_body.get("messages", [])
assert len(messages) > 0
assert "role" in messages[0]
assert "content" in messages[0]
# Mock response
return MockResponse(
200,
{
"text": "This is a mocked response from Cohere v2 API",
"generation_id": "mock-id",
"id": "mock-completion",
"usage": {"input_tokens": 10, "output_tokens": 20},
},
)
# Mock both sync and async HTTP clients
with patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post", side_effect=mock_sync_post):
with patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=AsyncMock, side_effect=mock_async_post):
litellm.set_verbose = True
messages = [{"role": "user", "content": "Hello from mock test"}]
response = completion(
model="cohere_v2/command-r",
messages=messages,
)
assert response is not None
assert "This is a mocked response" in response.choices[0].message.content
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_cohere_v2_request_body_with_allowed_params():
"""
Test to validate that the request body sent to the Cohere v2 endpoint is formatted
correctly (expected model name and v2-style messages array).
"""
try:
import httpx
class MockResponse:
def __init__(self, status_code, json_data, is_stream=False):
self.status_code = status_code
self._json_data = json_data
self.headers = {}
self.is_stream = is_stream
# For streaming responses with citations
if is_stream:
# Create streaming chunks with citations at the end
self._iter_content_chunks = [
json.dumps({"text": "Emperor"}).encode(),
json.dumps({"text": " penguins"}).encode(),
json.dumps({"text": " are"}).encode(),
json.dumps({"text": " the"}).encode(),
json.dumps({"text": " tallest"}).encode(),
json.dumps({"text": " and"}).encode(),
json.dumps({"text": " they"}).encode(),
json.dumps({"text": " live"}).encode(),
json.dumps({"text": " in"}).encode(),
json.dumps({"text": " Antarctica"}).encode(),
json.dumps({"text": "."}).encode(),
# Citations in a separate chunk
json.dumps({"citations": [
{
"start": 0,
"end": 30,
"text": "Emperor penguins are the tallest",
"document_ids": ["doc1"]
},
{
"start": 31,
"end": 70,
"text": "they live in Antarctica",
"document_ids": ["doc2"]
}
]}).encode(),
json.dumps({"finish_reason": "COMPLETE"}).encode(),
]
def json(self):
return self._json_data
@property
def text(self):
return json.dumps(self._json_data)
def iter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
async def aiter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
def mock_sync_post(*args, **kwargs):
# For synchronous HTTP client
data = kwargs.get("data", "{}")
request_body = json.loads(data)
print("Sync Request body:", request_body)
# Verify the model is passed correctly
assert request_body.get("model") == "command-r"
# Verify the messages are formatted correctly for v2
messages = request_body.get("messages", [])
assert len(messages) > 0
assert "role" in messages[0]
assert "content" in messages[0]
# Mock response
return MockResponse(
200,
{
"text": "This is a test response",
"generation_id": "test-id",
"id": "test",
"usage": {"input_tokens": 10, "output_tokens": 20},
},
)
async def mock_async_post(*args, **kwargs):
# For asynchronous HTTP client
data = kwargs.get("data", "{}")
request_body = json.loads(data)
print("Async Request body:", request_body)
# Verify the messages are formatted correctly for v2
messages = request_body.get("messages", [])
assert len(messages) > 0
assert "role" in messages[0]
assert "content" in messages[0]
# Mock response
return MockResponse(
200,
{
"text": "This is a test response",
"generation_id": "test-id",
"id": "test",
"usage": {"input_tokens": 10, "output_tokens": 20},
},
)
# Mock both sync and async HTTP clients
with patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post", side_effect=mock_sync_post):
with patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=AsyncMock, side_effect=mock_async_post):
litellm.set_verbose = True
messages = [{"role": "user", "content": "Hello"}]
response = completion(
model="cohere_v2/command-r",
messages=messages,
)
assert response is not None
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_chat_completion_cohere_v2_streaming_citations():
"""
Test specifically for streaming with citations in Cohere v2
"""
try:
class MockResponse:
def __init__(self, status_code, json_data, is_stream=False):
self.status_code = status_code
self._json_data = json_data
self.headers = {}
self.is_stream = is_stream
# For streaming responses with citations
if is_stream:
# Create streaming chunks with citations at the end
self._iter_content_chunks = [
json.dumps({"text": "Emperor"}).encode(),
json.dumps({"text": " penguins"}).encode(),
json.dumps({"text": " are"}).encode(),
json.dumps({"text": " the"}).encode(),
json.dumps({"text": " tallest"}).encode(),
json.dumps({"text": " and"}).encode(),
json.dumps({"text": " they"}).encode(),
json.dumps({"text": " live"}).encode(),
json.dumps({"text": " in"}).encode(),
json.dumps({"text": " Antarctica"}).encode(),
json.dumps({"text": "."}).encode(),
# Citations in a separate chunk
json.dumps({"citations": [
{
"start": 0,
"end": 30,
"text": "Emperor penguins are the tallest",
"document_ids": ["doc1"]
},
{
"start": 31,
"end": 70,
"text": "they live in Antarctica",
"document_ids": ["doc2"]
}
]}).encode(),
json.dumps({"finish_reason": "COMPLETE"}).encode(),
]
def json(self):
return self._json_data
@property
def text(self):
return json.dumps(self._json_data)
def iter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
async def aiter_lines(self):
if self.is_stream:
for chunk in self._iter_content_chunks:
yield chunk
else:
yield json.dumps(self._json_data).encode()
async def mock_async_post(*args, **kwargs):
# For asynchronous HTTP client
data = kwargs.get("data", "{}")
request_body = json.loads(data)
print("Async Request body:", request_body)
# Verify the messages are formatted correctly for v2
messages = request_body.get("messages", [])
assert len(messages) > 0
assert "role" in messages[0]
assert "content" in messages[0]
# Check if documents are included
documents = request_body.get("documents", [])
assert len(documents) > 0
# Verify stream is set to True
assert request_body.get("stream") == True
# Return a streaming response with citations
return MockResponse(
200,
{
"text": "Emperor penguins are the tallest penguins and they live in Antarctica.",
"generation_id": "mock-id",
"id": "mock-completion",
"usage": {"input_tokens": 10, "output_tokens": 20},
},
is_stream=True
)
# Mock the async HTTP client
with patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=AsyncMock, side_effect=mock_async_post):
litellm.set_verbose = True
messages = [
{
"role": "user",
"content": "Which penguins are the tallest?",
},
]
response = await litellm.acompletion(
model="cohere_v2/command-r",
messages=messages,
stream=True,
documents=[
{"title": "Tall penguins", "text": "Emperor penguins are the tallest."},
{
"title": "Penguin habitats",
"text": "Emperor penguins only live in Antarctica.",
},
],
)
# Verify we get streaming chunks with citations
citations_chunk = False
async for chunk in response:
print("received chunk", chunk)
if hasattr(chunk, "citations") or (isinstance(chunk, dict) and "citations" in chunk):
citations_chunk = True
break
assert citations_chunk, "No citations chunk was received"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="Only run this test when you want to test with a real API key")
@pytest.mark.asyncio
async def test_cohere_v2_real_api_call():
"""
Test for making a real API call to Cohere v2. This test is skipped by default.
To run this test, remove the skip mark and ensure you have a valid Cohere API key.
"""
try:
# Requires a valid key in the environment, e.g. export CO_API_KEY=<your-api-key>
assert os.environ.get("CO_API_KEY") or os.environ.get("COHERE_API_KEY"), "Set CO_API_KEY before running this test"
litellm.set_verbose = True
messages = [
{
"role": "user",
"content": "What is the capital of France?",
},
]
# Make a real API call
response = await litellm.acompletion(
model="cohere_v2/command-r",
messages=messages,
max_tokens=100,
)
print("Real API Response:", response)
assert response is not None
assert response.choices[0].message.content is not None
assert len(response.choices[0].message.content) > 0
# Test streaming with real API
stream_response = await litellm.acompletion(
model="cohere_v2/command-r",
messages=messages,
stream=True,
max_tokens=100,
)
# Verify we get streaming chunks
chunk_count = 0
async for chunk in stream_response:
print(f"Stream chunk: {chunk}")
chunk_count += 1
if chunk_count > 5: # Just check a few chunks to avoid long test
break
assert chunk_count > 0, "No streaming chunks were received"
except Exception as e:
pytest.fail(f"Error occurred with real API call: {e}")