Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
feat: cohere_v2 support
Commit ac9c9bb729 (parent ff3a6830a4)
10 changed files with 1480 additions and 6 deletions
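
For orientation, a minimal usage sketch of the route this commit adds, modeled on the new tests further down; it assumes a valid Cohere key in the environment and is illustrative rather than authoritative.

    import litellm

    # Route explicitly through the new v2 chat provider via the model prefix
    # (the tests below use exactly this form); COHERE_API_KEY or CO_API_KEY
    # must be set, or api_key= passed explicitly.
    response = litellm.completion(
        model="cohere_chat_v2/command-r",
        messages=[{"role": "user", "content": "Hey"}],
        max_tokens=10,
    )
    print(response.choices[0].message.content)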
@@ -368,6 +368,7 @@ open_ai_chat_completion_models: List = []
 open_ai_text_completion_models: List = []
 cohere_models: List = []
 cohere_chat_models: List = []
+cohere_chat_v2_models: List = []
 mistral_chat_models: List = []
 text_completion_codestral_models: List = []
 anthropic_models: List = []
@@ -464,6 +465,8 @@ def add_known_models():
             cohere_models.append(key)
         elif value.get("litellm_provider") == "cohere_chat":
             cohere_chat_models.append(key)
+        elif value.get("litellm_provider") == "cohere_chat_v2":
+            cohere_chat_v2_models.append(key)
         elif value.get("litellm_provider") == "mistral":
             mistral_chat_models.append(key)
         elif value.get("litellm_provider") == "anthropic":
@@ -605,6 +608,7 @@ model_list = (
     + open_ai_text_completion_models
     + cohere_models
     + cohere_chat_models
+    + cohere_chat_v2_models
     + anthropic_models
     + replicate_models
     + openrouter_models
@@ -655,8 +659,9 @@ provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
 models_by_provider: dict = {
     "openai": open_ai_chat_completion_models + open_ai_text_completion_models,
     "text-completion-openai": open_ai_text_completion_models,
-    "cohere": cohere_models + cohere_chat_models,
+    "cohere": cohere_models + cohere_chat_models + cohere_chat_v2_models,
     "cohere_chat": cohere_chat_models,
+    "cohere_chat_v2": cohere_chat_v2_models,
     "anthropic": anthropic_models,
     "replicate": replicate_models,
     "huggingface": huggingface_models,
@@ -919,6 +924,7 @@ from .llms.bedrock.embed.amazon_titan_v2_transformation import (
     AmazonTitanV2Config,
 )
 from .llms.cohere.chat.transformation import CohereChatConfig
+from .llms.cohere.chat.transformation_v2 import CohereChatConfigV2
 from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
 from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
 from .llms.openai.image_variations.transformation import OpenAIImageVariationConfig
@@ -96,6 +96,7 @@ LITELLM_CHAT_PROVIDERS = [
     "text-completion-openai",
     "cohere",
     "cohere_chat",
+    "cohere_chat_v2",
     "clarifai",
     "anthropic",
     "anthropic_text",
@@ -23,14 +23,16 @@ def _is_non_openai_azure_model(model: str) -> bool:


 def handle_cohere_chat_model_custom_llm_provider(
-    model: str, custom_llm_provider: Optional[str] = None
+    model: str, custom_llm_provider: Optional[str] = None, api_version: Optional[str] = None
 ) -> Tuple[str, Optional[str]]:
     """
     if user sets model = "cohere/command-r" -> use custom_llm_provider = "cohere_chat"
+    if api_version = "v2" -> use custom_llm_provider = "cohere_chat_v2"

     Args:
-        model:
-        custom_llm_provider:
+        model: The model name
+        custom_llm_provider: The custom LLM provider if specified
+        api_version: The API version (v1 or v2)

     Returns:
         model, custom_llm_provider
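
A hedged sketch of the routing outcomes the updated helper is aiming for, inferred from the docstring above (the import path is an assumption; this diff does not show which module the helper lives in):

    # Assumed module path; not shown in this diff.
    from litellm.litellm_core_utils.get_llm_provider_logic import (
        handle_cohere_chat_model_custom_llm_provider,
    )

    # Default routing keeps the existing v1 chat provider.
    assert handle_cohere_chat_model_custom_llm_provider("cohere/command-r", None) == (
        "command-r",
        "cohere_chat",
    )
    # Passing api_version="v2" flips the same model onto the new v2 route.
    assert handle_cohere_chat_model_custom_llm_provider(
        "cohere/command-r", None, api_version="v2"
    ) == ("command-r", "cohere_chat_v2")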
@@ -38,6 +40,9 @@ def handle_cohere_chat_model_custom_llm_provider(

     if custom_llm_provider:
         if custom_llm_provider == "cohere" and model in litellm.cohere_chat_models:
+            # Check if v2 API version is specified
+            if api_version == "v2":
+                return model, "cohere_chat_v2"
             return model, "cohere_chat"

     if "/" in model:
@@ -47,6 +52,9 @@ def handle_cohere_chat_model_custom_llm_provider(
             and _custom_llm_provider == "cohere"
             and _model in litellm.cohere_chat_models
         ):
+            # Check if v2 API version is specified
+            if api_version == "v2":
+                return _model, "cohere_chat_v2"
             return _model, "cohere_chat"

     return model, custom_llm_provider
@@ -122,8 +130,18 @@ def get_llm_provider(  # noqa: PLR0915
             return model, custom_llm_provider, dynamic_api_key, api_base

         ### Handle cases when custom_llm_provider is set to cohere/command-r-plus but it should use cohere_chat route
+        # Extract api_version from optional_params if it exists
+        api_version = None
+        if litellm_params and hasattr(litellm_params, "optional_params") and litellm_params.optional_params:
+            api_version = litellm_params.optional_params.get("api_version")
+
+        # Handle direct cohere_chat_v2 model format
+        if model.startswith("cohere_chat_v2/"):
+            model = model.replace("cohere_chat_v2/", "")
+            custom_llm_provider = "cohere_chat_v2"
+
         model, custom_llm_provider = handle_cohere_chat_model_custom_llm_provider(
-            model, custom_llm_provider
+            model, custom_llm_provider, api_version
         )

         model, custom_llm_provider = handle_anthropic_text_model_custom_llm_provider(
@@ -2007,6 +2007,57 @@ def cohere_messages_pt_v2(  # noqa: PLR0915
     return returned_message, new_messages


+def cohere_messages_pt_v3(messages: List, model: str, llm_provider: str):
+    """
+    Format messages for Cohere v2 API
+
+    In v2, messages are combined in a single array with the following format:
+    [
+        {"role": "USER", "content": "Hello"},
+        {"role": "ASSISTANT", "content": "Hi there!"},
+        {"role": "USER", "content": "How are you?"}
+    ]
+
+    Returns:
+        List of formatted messages in Cohere v2 format
+    """
+    cohere_messages = []
+
+    for msg_i, message in enumerate(messages):
+        role = message["role"].upper()
+
+        # Map OpenAI roles to Cohere v2 roles
+        if role == "USER":
+            pass  # Keep as USER
+        elif role == "ASSISTANT":
+            role = "CHATBOT"  # Cohere v2 uses CHATBOT instead of ASSISTANT
+        elif role == "SYSTEM":
+            role = "USER"  # System messages are sent as USER with a special prefix
+            message["content"] = f"<admin>{message['content']}</admin>"
+        elif role == "TOOL":
+            # Skip tool messages as they'll be handled separately with tool_results
+            continue
+        elif role == "FUNCTION":
+            # Skip function messages as they'll be handled separately with tool_results
+            continue
+
+        # Handle content
+        content = ""
+        if isinstance(message.get("content"), str):
+            content = message["content"]
+        elif isinstance(message.get("content"), list):
+            # Handle content list (text and images)
+            for item in message["content"]:
+                if isinstance(item, dict):
+                    if item.get("type") == "text":
+                        content += item.get("text", "")
+
+        # Add message to the list
+        cohere_messages.append({"role": role, "content": content})
+
+    return cohere_messages
+
+
 def cohere_message_pt(messages: list):
     tool_calls: List = get_all_tool_calls(messages=messages)
     prompt = ""
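
To make the new helper's behavior concrete, a small illustration derived from the role-mapping branches above (the import path is the one transformation_v2.py below uses; the `<admin>` wrapping is simply what this code does with system messages, not an official Cohere convention):

    from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v3

    openai_messages = [
        {"role": "system", "content": "You're a good bot"},
        {"role": "user", "content": "Hey"},
        {"role": "assistant", "content": "Hi there!"},
    ]
    cohere_messages = cohere_messages_pt_v3(
        messages=openai_messages, model="command-r", llm_provider="cohere_chat"
    )
    # Expected result per the branches above:
    # [
    #     {"role": "USER", "content": "<admin>You're a good bot</admin>"},
    #     {"role": "USER", "content": "Hey"},
    #     {"role": "CHATBOT", "content": "Hi there!"},
    # ]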
litellm/llms/cohere/chat/transformation_v2.py (new file, 353 lines)
@@ -0,0 +1,353 @@
"""Cohere Chat V2 API Integration Module.

This module provides the necessary classes and functions to interact with Cohere's V2 Chat API.
It handles the transformation of requests and responses between LiteLLM's standard format and
Cohere's specific API requirements.
"""

import json
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v3
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage

# Use absolute imports instead of relative imports
from litellm.llms.cohere.common_utils import ModelResponseIterator as CohereModelResponseIterator
from litellm.llms.cohere.common_utils import validate_environment as cohere_validate_environment

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class CohereErrorV2(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[httpx.Headers] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(method="POST", url="https://api.cohere.com/v2/chat")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers,
        )


class CohereChatConfigV2(BaseConfig):
    """
    Configuration class for Cohere's V2 API interface.

    Args:
        preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
        generation_id (str, optional): Unique identifier for the generated reply.
        conversation_id (str, optional): Creates or resumes a persisted conversation.
        prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
        connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
        search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
        documents (List[Dict[str, str]] or List[str], optional): A list of relevant documents that the model can cite.
        temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
        max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
        k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
        p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
        frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
        presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
        tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
        tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
        seed (int, optional): A seed to assist reproducibility of the model's response.
    """

    preamble: Optional[str] = None
    generation_id: Optional[str] = None
    conversation_id: Optional[str] = None
    prompt_truncation: Optional[str] = None
    connectors: Optional[list] = None
    search_queries_only: Optional[bool] = None
    documents: Optional[list] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    k: Optional[int] = None
    p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    tools: Optional[list] = None
    tool_results: Optional[list] = None
    seed: Optional[int] = None

    def __init__(
        self,
        preamble: Optional[str] = None,
        generation_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        prompt_truncation: Optional[str] = None,
        connectors: Optional[list] = None,
        search_queries_only: Optional[bool] = None,
        documents: Optional[list] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        k: Optional[int] = None,
        p: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        tools: Optional[list] = None,
        tool_results: Optional[list] = None,
        seed: Optional[int] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        return cohere_validate_environment(
            headers=headers,
            model=model,
            messages=messages,
            optional_params=optional_params,
            api_key=api_key,
            api_version="v2",  # Specify v2 API version
        )

    def get_supported_openai_params(self, model: str) -> List[str]:
        return [
            "stream",
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "stop",
            "n",
            "tools",
            "tool_choice",
            "seed",
            "extra_headers",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for param, value in non_default_params.items():
            if param == "stream":
                optional_params["stream"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "n":
                optional_params["num_generations"] = value
            if param == "top_p":
                optional_params["p"] = value
            if param == "frequency_penalty":
                optional_params["frequency_penalty"] = value
            if param == "presence_penalty":
                optional_params["presence_penalty"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value
            if param == "tools":
                optional_params["tools"] = value
            if param == "seed":
                optional_params["seed"] = value
        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        ## Load Config
        for k, v in litellm.CohereChatConfigV2.get_config().items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        # In v2, messages are combined in a single array
        cohere_messages = cohere_messages_pt_v3(
            messages=messages, model=model, llm_provider="cohere_chat"
        )
        optional_params["messages"] = cohere_messages
        optional_params["model"] = model.split("/")[-1]  # Extract model name from model string

        ## Handle Tool Calling
        if "tools" in optional_params:
            cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
            optional_params["tools"] = cohere_tools

        # Handle tool results if present
        if "tool_results" in optional_params and isinstance(optional_params["tool_results"], list):
            # Convert tool results to v2 format if needed
            tool_results = []
            for result in optional_params["tool_results"]:
                if isinstance(result, dict) and "content" in result:
                    # Format from v1 to v2
                    tool_result = {
                        "tool_call_id": result.get("tool_call_id", ""),
                        "output": result.get("content", ""),
                    }
                    tool_results.append(tool_result)
                else:
                    # Already in v2 format
                    tool_results.append(result)
            optional_params["tool_results"] = tool_results

        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            raw_response_json = raw_response.json()
            model_response.choices[0].message.content = raw_response_json.get("text", "")  # type: ignore
        except Exception:
            raise CohereErrorV2(
                message=raw_response.text, status_code=raw_response.status_code
            )

        ## ADD CITATIONS
        if "citations" in raw_response_json:
            setattr(model_response, "citations", raw_response_json["citations"])

        ## Tool calling response
        cohere_tools_response = raw_response_json.get("tool_calls", None)
        if cohere_tools_response is not None and cohere_tools_response != []:
            # convert cohere_tools_response to OpenAI response format
            tool_calls = []
            for tool in cohere_tools_response:
                function_name = tool.get("name", "")
                tool_call_id = tool.get("id", "")
                parameters = tool.get("parameters", {})
                tool_call = {
                    "id": tool_call_id,
                    "type": "function",
                    "function": {
                        "name": function_name,
                        "arguments": json.dumps(parameters),
                    },
                }
                tool_calls.append(tool_call)
            _message = litellm.Message(
                tool_calls=tool_calls,
                content=None,
            )
            model_response.choices[0].message = _message  # type: ignore

        ## CALCULATING USAGE - use cohere `billed_units` for returning usage
        billed_units = raw_response_json.get("usage", {})

        prompt_tokens = billed_units.get("input_tokens", 0)
        completion_tokens = billed_units.get("output_tokens", 0)

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

    def _construct_cohere_tool(
        self,
        tools: Optional[list] = None,
    ):
        if tools is None:
            tools = []
        cohere_tools = []
        for tool in tools:
            cohere_tool = self._translate_openai_tool_to_cohere(tool)
            cohere_tools.append(cohere_tool)
        return cohere_tools

    def _translate_openai_tool_to_cohere(
        self,
        openai_tool: dict,
    ):
        """
        Translates OpenAI tool format to Cohere v2 tool format

        Cohere v2 tools look like this:
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "input_schema": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location"]
            }
        }
        """

        cohere_tool = {
            "name": openai_tool["function"]["name"],
            "description": openai_tool["function"]["description"],
            "input_schema": openai_tool["function"]["parameters"],
        }

        return cohere_tool

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return CohereModelResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return CohereErrorV2(status_code=status_code, message=error_message)
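
As a quick reference for the parameter translation the config above performs, a hedged sketch (values are arbitrary; the expected output follows the map_openai_params branches exactly):

    # Exercises CohereChatConfigV2.map_openai_params as defined in the new file above.
    from litellm.llms.cohere.chat.transformation_v2 import CohereChatConfigV2

    config = CohereChatConfigV2()
    mapped = config.map_openai_params(
        non_default_params={"top_p": 0.9, "stop": ["\n\n"], "n": 2, "max_tokens": 64},
        optional_params={},
        model="command-r",
        drop_params=False,
    )
    # Expected per the branches above:
    # {"p": 0.9, "stop_sequences": ["\n\n"], "num_generations": 2, "max_tokens": 64}
    print(mapped)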
@@ -21,11 +21,15 @@ def validate_environment(
     messages: List[AllMessageValues],
     optional_params: dict,
     api_key: Optional[str] = None,
+    api_version: Optional[str] = "v1",
 ) -> dict:
     """
     Return headers to use for cohere chat completion request

-    Cohere API Ref: https://docs.cohere.com/reference/chat
+    Cohere API Ref:
+    - v1: https://docs.cohere.com/reference/chat
+    - v2: https://docs.cohere.com/v2/reference/chat
+
     Expected headers:
     {
         "Request-Source": "unspecified:litellm",
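
Within this hunk only the signature and docstring change; the new argument records which API generation the caller targets. A hedged sketch of calling the updated helper (the import path matches what transformation_v2.py above uses):

    from litellm.llms.cohere.common_utils import validate_environment

    headers = validate_environment(
        headers={},
        model="command-r",
        messages=[{"role": "user", "content": "Hey"}],
        optional_params={},
        api_key="placeholder-key",  # placeholder, not a real key
        api_version="v2",
    )
    # Expected to include "Request-Source": "unspecified:litellm" per the docstring above.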
@@ -2108,6 +2108,45 @@ def completion(  # type: ignore # noqa: PLR0915
             api_key=cohere_key,
             logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
         )
+    elif custom_llm_provider == "cohere_chat_v2":
+        cohere_key = (
+            api_key
+            or litellm.cohere_key
+            or get_secret_str("COHERE_API_KEY")
+            or get_secret_str("CO_API_KEY")
+            or litellm.api_key
+        )
+
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret_str("COHERE_API_BASE")
+            or "https://api.cohere.ai/v2/chat"
+        )
+
+        headers = headers or litellm.headers or {}
+        if headers is None:
+            headers = {}
+
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
+        response = base_llm_http_handler.completion(
+            model=model,
+            stream=stream,
+            messages=messages,
+            acompletion=acompletion,
+            api_base=api_base,
+            model_response=model_response,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            custom_llm_provider="cohere_chat_v2",
+            timeout=timeout,
+            headers=headers,
+            encoding=encoding,
+            api_key=cohere_key,
+            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+        )
     elif custom_llm_provider == "maritalk":
         maritalk_key = (
             api_key
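
Key and base-URL resolution for the new branch follows the order shown above. A minimal sketch of driving it from the environment (placeholder values; the only default taken from the hunk is the v2 chat URL):

    import os
    import litellm

    os.environ["COHERE_API_KEY"] = "placeholder"  # or CO_API_KEY, or pass api_key=
    response = litellm.completion(
        model="cohere_chat_v2/command-r",
        messages=[{"role": "user", "content": "Which penguins are the tallest?"}],
        api_base="https://api.cohere.ai/v2/chat",  # optional; matches the default in the branch above
    )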
@@ -2004,6 +2004,7 @@ class LlmProviders(str, Enum):
     TEXT_COMPLETION_OPENAI = "text-completion-openai"
     COHERE = "cohere"
     COHERE_CHAT = "cohere_chat"
+    COHERE_CHAT_V2 = "cohere_chat_v2"
     CLARIFAI = "clarifai"
     ANTHROPIC = "anthropic"
     ANTHROPIC_TEXT = "anthropic_text"
@@ -6233,6 +6233,8 @@ class ProviderConfigManager:
             return litellm.OpenAITextCompletionConfig()
         elif litellm.LlmProviders.COHERE_CHAT == provider:
             return litellm.CohereChatConfig()
+        elif litellm.LlmProviders.COHERE_CHAT_V2 == provider:
+            return litellm.CohereChatConfigV2()
         elif litellm.LlmProviders.COHERE == provider:
             return litellm.CohereConfig()
         elif litellm.LlmProviders.SNOWFLAKE == provider:
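
A hedged sketch of what the new registry entry resolves to (the enclosing method is not visible in this hunk; the class location and the get_provider_chat_config name are assumptions drawn from current litellm):

    import litellm
    from litellm.utils import ProviderConfigManager  # assumed module for the class in this hunk

    config = ProviderConfigManager.get_provider_chat_config(  # method name assumed
        model="command-r", provider=litellm.LlmProviders.COHERE_CHAT_V2
    )
    print(type(config).__name__)  # expected: CohereChatConfigV2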
tests/llm_translation/test_cohere_v2.py (new file, 999 lines)
@@ -0,0 +1,999 @@
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()

# For testing, make sure the COHERE_API_KEY or CO_API_KEY environment variable is set
# You can set it before running the tests with: export COHERE_API_KEY=your_api_key
import io
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import json

import pytest

import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from unittest.mock import AsyncMock, patch
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

litellm.num_retries = 3

|
@pytest.mark.parametrize("stream", [True, False])
|
||||||
|
@pytest.mark.flaky(retries=3, delay=1)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_completion_cohere_v2_citations(stream):
|
||||||
|
try:
|
||||||
|
class MockResponse:
|
||||||
|
def __init__(self, status_code, json_data, is_stream=False):
|
||||||
|
self.status_code = status_code
|
||||||
|
self._json_data = json_data
|
||||||
|
self.headers = {}
|
||||||
|
self.is_stream = is_stream
|
||||||
|
|
||||||
|
# For streaming responses with citations
|
||||||
|
if is_stream:
|
||||||
|
# Create streaming chunks with citations at the end
|
||||||
|
self._iter_content_chunks = [
|
||||||
|
json.dumps({"text": "Emperor"}).encode(),
|
||||||
|
json.dumps({"text": " penguins"}).encode(),
|
||||||
|
json.dumps({"text": " are"}).encode(),
|
||||||
|
json.dumps({"text": " the"}).encode(),
|
||||||
|
json.dumps({"text": " tallest"}).encode(),
|
||||||
|
json.dumps({"text": " and"}).encode(),
|
||||||
|
json.dumps({"text": " they"}).encode(),
|
||||||
|
json.dumps({"text": " live"}).encode(),
|
||||||
|
json.dumps({"text": " in"}).encode(),
|
||||||
|
json.dumps({"text": " Antarctica"}).encode(),
|
||||||
|
json.dumps({"text": "."}).encode(),
|
||||||
|
# Citations in a separate chunk
|
||||||
|
json.dumps({"citations": [
|
||||||
|
{
|
||||||
|
"start": 0,
|
||||||
|
"end": 30,
|
||||||
|
"text": "Emperor penguins are the tallest",
|
||||||
|
"document_ids": ["doc1"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"start": 31,
|
||||||
|
"end": 70,
|
||||||
|
"text": "they live in Antarctica",
|
||||||
|
"document_ids": ["doc2"]
|
||||||
|
}
|
||||||
|
]}).encode(),
|
||||||
|
json.dumps({"finish_reason": "COMPLETE"}).encode(),
|
||||||
|
]
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return self._json_data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self):
|
||||||
|
return json.dumps(self._json_data)
|
||||||
|
|
||||||
|
def iter_lines(self):
|
||||||
|
if self.is_stream:
|
||||||
|
for chunk in self._iter_content_chunks:
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
yield json.dumps(self._json_data).encode()
|
||||||
|
|
||||||
|
async def aiter_lines(self):
|
||||||
|
if self.is_stream:
|
||||||
|
for chunk in self._iter_content_chunks:
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
yield json.dumps(self._json_data).encode()
|
||||||
|
|
||||||
|
async def mock_async_post(*args, **kwargs):
|
||||||
|
# For asynchronous HTTP client
|
||||||
|
data = kwargs.get("data", "{}")
|
||||||
|
request_body = json.loads(data)
|
||||||
|
print("Async Request body:", request_body)
|
||||||
|
|
||||||
|
# Verify the messages are formatted correctly for v2
|
||||||
|
messages = request_body.get("messages", [])
|
||||||
|
assert len(messages) > 0
|
||||||
|
assert "role" in messages[0]
|
||||||
|
assert "content" in messages[0]
|
||||||
|
|
||||||
|
# Check if documents are included
|
||||||
|
documents = request_body.get("documents", [])
|
||||||
|
assert len(documents) > 0
|
||||||
|
|
||||||
|
# Mock response with citations
|
||||||
|
mock_response = {
|
||||||
|
"text": "Emperor penguins are the tallest penguins and they live in Antarctica.",
|
||||||
|
"generation_id": "mock-id",
|
||||||
|
"id": "mock-completion",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 20},
|
||||||
|
"citations": [
|
||||||
|
{
|
||||||
|
"start": 0,
|
||||||
|
"end": 30,
|
||||||
|
"text": "Emperor penguins are the tallest",
|
||||||
|
"document_ids": ["doc1"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"start": 31,
|
||||||
|
"end": 70,
|
||||||
|
"text": "they live in Antarctica",
|
||||||
|
"document_ids": ["doc2"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create a streaming response with citations
|
||||||
|
if stream:
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"text": "Emperor penguins are the tallest penguins and they live in Antarctica.",
|
||||||
|
"generation_id": "mock-id",
|
||||||
|
"id": "mock-completion",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 20},
|
||||||
|
"citations": [
|
||||||
|
{
|
||||||
|
"start": 0,
|
||||||
|
"end": 30,
|
||||||
|
"text": "Emperor penguins are the tallest",
|
||||||
|
"document_ids": ["doc1"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"start": 31,
|
||||||
|
"end": 70,
|
||||||
|
"text": "they live in Antarctica",
|
||||||
|
"document_ids": ["doc2"]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stream": True
|
||||||
|
},
|
||||||
|
is_stream=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return MockResponse(200, mock_response)
|
||||||
|
|
||||||
|
# Mock the async HTTP client
|
||||||
|
with patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=AsyncMock, side_effect=mock_async_post):
|
||||||
|
litellm.set_verbose = True
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Which penguins are the tallest?",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="cohere_chat_v2/command-r",
|
||||||
|
messages=messages,
|
||||||
|
stream=stream,
|
||||||
|
documents=[
|
||||||
|
{"title": "Tall penguins", "text": "Emperor penguins are the tallest."},
|
||||||
|
{
|
||||||
|
"title": "Penguin habitats",
|
||||||
|
"text": "Emperor penguins only live in Antarctica.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
citations_chunk = False
|
||||||
|
async for chunk in response:
|
||||||
|
print("received chunk", chunk)
|
||||||
|
if hasattr(chunk, "citations") or (isinstance(chunk, dict) and "citations" in chunk):
|
||||||
|
citations_chunk = True
|
||||||
|
break
|
||||||
|
assert citations_chunk
|
||||||
|
else:
|
||||||
|
assert hasattr(response, "citations")
|
||||||
|
except litellm.ServiceUnavailableError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion_cohere_v2_command_r_plus_function_call():
|
||||||
|
litellm.set_verbose = True
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
# test without max tokens
|
||||||
|
response = completion(
|
||||||
|
model="command-r-plus",
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
|
api_version="v2", # Specify v2 API version
|
||||||
|
)
|
||||||
|
# Add any assertions, here to check response args
|
||||||
|
print(response)
|
||||||
|
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
|
||||||
|
assert isinstance(
|
||||||
|
response.choices[0].message.tool_calls[0].function.arguments, str
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(
|
||||||
|
response.choices[0].message.model_dump()
|
||||||
|
) # Add assistant tool invokes
|
||||||
|
tool_result = (
|
||||||
|
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
|
||||||
|
)
|
||||||
|
# Add user submitted tool results in the OpenAI format
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"tool_call_id": response.choices[0].message.tool_calls[0].id,
|
||||||
|
"role": "tool",
|
||||||
|
"name": response.choices[0].message.tool_calls[0].function.name,
|
||||||
|
"content": tool_result,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# In the second response, Cohere should deduce answer from tool results
|
||||||
|
second_response = completion(
|
||||||
|
model="command-r-plus",
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
|
force_single_step=True,
|
||||||
|
api_version="v2", # Specify v2 API version
|
||||||
|
)
|
||||||
|
print(second_response)
|
||||||
|
except litellm.Timeout:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.flaky(retries=6, delay=1)
|
||||||
|
def test_completion_cohere_v2():
|
||||||
|
try:
|
||||||
|
# litellm.set_verbose=True
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You're a good bot"},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hey",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
response = completion(
|
||||||
|
model="command-r",
|
||||||
|
messages=messages,
|
||||||
|
api_version="v2", # Specify v2 API version
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
|
async def test_chat_completion_cohere_v2(sync_mode):
|
||||||
|
try:
|
||||||
|
class MockResponse:
|
||||||
|
def __init__(self, status_code, json_data, is_stream=False):
|
||||||
|
self.status_code = status_code
|
||||||
|
self._json_data = json_data
|
||||||
|
self.headers = {}
|
||||||
|
self.is_stream = is_stream
|
||||||
|
|
||||||
|
# For streaming responses with citations
|
||||||
|
if is_stream:
|
||||||
|
# Create streaming chunks with citations at the end
|
||||||
|
self._iter_content_chunks = [
|
||||||
|
json.dumps({"text": "Emperor"}).encode(),
|
||||||
|
json.dumps({"text": " penguins"}).encode(),
|
||||||
|
json.dumps({"text": " are"}).encode(),
|
||||||
|
json.dumps({"text": " the"}).encode(),
|
||||||
|
json.dumps({"text": " tallest"}).encode(),
|
||||||
|
json.dumps({"text": " and"}).encode(),
|
||||||
|
json.dumps({"text": " they"}).encode(),
|
||||||
|
json.dumps({"text": " live"}).encode(),
|
||||||
|
json.dumps({"text": " in"}).encode(),
|
||||||
|
json.dumps({"text": " Antarctica"}).encode(),
|
||||||
|
json.dumps({"text": "."}).encode(),
|
||||||
|
# Citations in a separate chunk
|
||||||
|
json.dumps({"citations": [
|
||||||
|
{
|
||||||
|
"start": 0,
|
||||||
|
"end": 30,
|
||||||
|
"text": "Emperor penguins are the tallest",
|
||||||
|
"document_ids": ["doc1"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"start": 31,
|
||||||
|
"end": 70,
|
||||||
|
"text": "they live in Antarctica",
|
||||||
|
"document_ids": ["doc2"]
|
||||||
|
}
|
||||||
|
]}).encode(),
|
||||||
|
json.dumps({"finish_reason": "COMPLETE"}).encode(),
|
||||||
|
]
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return self._json_data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self):
|
||||||
|
return json.dumps(self._json_data)
|
||||||
|
|
||||||
|
def iter_lines(self):
|
||||||
|
if self.is_stream:
|
||||||
|
for chunk in self._iter_content_chunks:
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
yield json.dumps(self._json_data).encode()
|
||||||
|
|
||||||
|
async def aiter_lines(self):
|
||||||
|
if self.is_stream:
|
||||||
|
for chunk in self._iter_content_chunks:
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
yield json.dumps(self._json_data).encode()
|
||||||
|
|
||||||
|
def mock_sync_post(*args, **kwargs):
|
||||||
|
# For synchronous HTTP client
|
||||||
|
data = kwargs.get("data", "{}")
|
||||||
|
request_body = json.loads(data)
|
||||||
|
print("Sync Request body:", request_body)
|
||||||
|
|
||||||
|
# Verify the model is passed correctly
|
||||||
|
assert request_body.get("model") == "command-r"
|
||||||
|
|
||||||
|
# Verify max_tokens is passed correctly
|
||||||
|
assert request_body.get("max_tokens") == 10
|
||||||
|
|
||||||
|
# Verify the messages are formatted correctly for v2
|
||||||
|
messages = request_body.get("messages", [])
|
||||||
|
assert len(messages) > 0
|
||||||
|
assert "role" in messages[0]
|
||||||
|
assert "content" in messages[0]
|
||||||
|
|
||||||
|
# Mock response
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"text": "This is a mocked response for sync request",
|
||||||
|
"generation_id": "mock-id",
|
||||||
|
"id": "mock-completion",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 20},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_async_post(*args, **kwargs):
|
||||||
|
# For asynchronous HTTP client
|
||||||
|
data = kwargs.get("data", "{}")
|
||||||
|
request_body = json.loads(data)
|
||||||
|
print("Async Request body:", request_body)
|
||||||
|
|
||||||
|
# Verify the model is passed correctly
|
||||||
|
assert request_body.get("model") == "command-r"
|
||||||
|
|
||||||
|
# Verify max_tokens is passed correctly
|
||||||
|
assert request_body.get("max_tokens") == 10
|
||||||
|
|
||||||
|
# Verify the messages are formatted correctly for v2
|
||||||
|
messages = request_body.get("messages", [])
|
||||||
|
assert len(messages) > 0
|
||||||
|
assert "role" in messages[0]
|
||||||
|
assert "content" in messages[0]
|
||||||
|
|
||||||
|
# Mock response
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"text": "This is a mocked response for async request",
|
||||||
|
"generation_id": "mock-id",
|
||||||
|
"id": "mock-completion",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 20},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock both sync and async HTTP clients
|
||||||
|
with patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post", side_effect=mock_sync_post):
|
||||||
|
with patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=AsyncMock, side_effect=mock_async_post):
|
||||||
|
litellm.set_verbose = True
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You're a good bot"},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hey",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
if sync_mode is False:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="cohere_chat_v2/command-r",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = completion(
|
||||||
|
model="cohere_chat_v2/command-r",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
assert response is not None
|
||||||
|
assert "This is a mocked response" in response.choices[0].message.content
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("sync_mode", [False])
|
||||||
|
async def test_chat_completion_cohere_v2_stream(sync_mode):
|
||||||
|
try:
|
||||||
|
class MockResponse:
|
||||||
|
def __init__(self, status_code, json_data, is_stream=False):
|
||||||
|
self.status_code = status_code
|
||||||
|
self._json_data = json_data
|
||||||
|
self.headers = {}
|
||||||
|
self.is_stream = is_stream
|
||||||
|
|
||||||
|
# For streaming responses
|
||||||
|
if is_stream:
|
||||||
|
self._iter_content_chunks = [
|
||||||
|
json.dumps({"text": "This"}).encode(),
|
||||||
|
json.dumps({"text": " is"}).encode(),
|
||||||
|
json.dumps({"text": " a"}).encode(),
|
||||||
|
json.dumps({"text": " streamed"}).encode(),
|
||||||
|
json.dumps({"text": " response"}).encode(),
|
||||||
|
json.dumps({"text": "."}).encode(),
|
||||||
|
json.dumps({"finish_reason": "COMPLETE"}).encode(),
|
||||||
|
]
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return self._json_data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self):
|
||||||
|
return json.dumps(self._json_data)
|
||||||
|
|
||||||
|
def iter_lines(self):
|
||||||
|
if self.is_stream:
|
||||||
|
for chunk in self._iter_content_chunks:
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
yield json.dumps(self._json_data).encode()
|
||||||
|
|
||||||
|
async def aiter_lines(self):
|
||||||
|
if self.is_stream:
|
||||||
|
for chunk in self._iter_content_chunks:
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
yield json.dumps(self._json_data).encode()
|
||||||
|
|
||||||
|
async def mock_async_post(*args, **kwargs):
|
||||||
|
# For asynchronous HTTP client
|
||||||
|
data = kwargs.get("data", "{}")
|
||||||
|
request_body = json.loads(data)
|
||||||
|
print("Async Request body:", request_body)
|
||||||
|
|
||||||
|
# Verify the model is passed correctly
|
||||||
|
assert request_body.get("model") == "command-r"
|
||||||
|
|
||||||
|
# Verify max_tokens is passed correctly
|
||||||
|
assert request_body.get("max_tokens") == 10
|
||||||
|
|
||||||
|
# Verify stream is set to True
|
||||||
|
assert request_body.get("stream") == True
|
||||||
|
|
||||||
|
# Verify the messages are formatted correctly for v2
|
||||||
|
messages = request_body.get("messages", [])
|
||||||
|
assert len(messages) > 0
|
||||||
|
assert "role" in messages[0]
|
||||||
|
assert "content" in messages[0]
|
||||||
|
|
||||||
|
# Return a streaming response
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"text": "This is a streamed response.",
|
||||||
|
"generation_id": "mock-id",
|
||||||
|
"id": "mock-completion",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 20},
|
||||||
|
},
|
||||||
|
is_stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the async HTTP client for streaming
|
||||||
|
with patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=AsyncMock, side_effect=mock_async_post):
|
||||||
|
litellm.set_verbose = True
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You're a good bot"},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hey",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
if sync_mode is False:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="cohere_chat_v2/command-r",
|
||||||
|
messages=messages,
|
||||||
|
stream=True,
|
||||||
|
max_tokens=10,
|
||||||
|
)
|
||||||
|
# Verify we get streaming chunks
|
||||||
|
chunk_count = 0
|
||||||
|
async for chunk in response:
|
||||||
|
print(f"chunk: {chunk}")
|
||||||
|
chunk_count += 1
|
||||||
|
assert chunk_count > 0, "No streaming chunks were received"
|
||||||
|
else:
|
||||||
|
# This test is only for async mode
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_v2_mock_completion():
|
||||||
|
"""
|
||||||
|
Test cohere_chat_v2 completion with mocked responses to avoid API calls
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
class MockResponse:
|
||||||
|
def __init__(self, status_code, json_data, is_stream=False):
|
||||||
|
self.status_code = status_code
|
||||||
|
self._json_data = json_data
|
||||||
|
self.headers = {}
|
||||||
|
self.is_stream = is_stream
|
||||||
|
|
||||||
|
# For streaming responses with citations
|
||||||
|
if is_stream:
|
||||||
|
# Create streaming chunks with citations at the end
|
||||||
|
self._iter_content_chunks = [
|
||||||
|
json.dumps({"text": "Emperor"}).encode(),
|
||||||
|
json.dumps({"text": " penguins"}).encode(),
|
||||||
|
json.dumps({"text": " are"}).encode(),
|
||||||
|
json.dumps({"text": " the"}).encode(),
|
||||||
|
json.dumps({"text": " tallest"}).encode(),
|
||||||
|
json.dumps({"text": " and"}).encode(),
|
||||||
|
json.dumps({"text": " they"}).encode(),
|
||||||
|
json.dumps({"text": " live"}).encode(),
|
||||||
|
json.dumps({"text": " in"}).encode(),
|
||||||
|
json.dumps({"text": " Antarctica"}).encode(),
|
||||||
|
json.dumps({"text": "."}).encode(),
|
||||||
|
# Citations in a separate chunk
|
||||||
|
json.dumps({"citations": [
|
||||||
|
{
|
||||||
|
"start": 0,
|
||||||
|
"end": 30,
|
||||||
|
"text": "Emperor penguins are the tallest",
|
||||||
|
"document_ids": ["doc1"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"start": 31,
|
||||||
|
"end": 70,
|
||||||
|
"text": "they live in Antarctica",
|
||||||
|
"document_ids": ["doc2"]
|
||||||
|
}
|
||||||
|
]}).encode(),
|
||||||
|
json.dumps({"finish_reason": "COMPLETE"}).encode(),
|
||||||
|
]
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return self._json_data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self):
|
||||||
|
return json.dumps(self._json_data)
|
||||||
|
|
||||||
|
def iter_lines(self):
|
||||||
|
if self.is_stream:
|
||||||
|
for chunk in self._iter_content_chunks:
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
yield json.dumps(self._json_data).encode()
|
||||||
|
|
||||||
|
async def aiter_lines(self):
|
||||||
|
if self.is_stream:
|
||||||
|
for chunk in self._iter_content_chunks:
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
yield json.dumps(self._json_data).encode()
|
||||||
|
|
||||||
|
def mock_sync_post(*args, **kwargs):
|
||||||
|
# For synchronous HTTP client
|
||||||
|
data = kwargs.get("data", "{}")
|
||||||
|
request_body = json.loads(data)
|
||||||
|
print("Sync Request body:", request_body)
|
||||||
|
|
||||||
|
# Verify the messages are formatted correctly for v2
|
||||||
|
messages = request_body.get("messages", [])
|
||||||
|
assert len(messages) > 0
|
||||||
|
assert "role" in messages[0]
|
||||||
|
assert "content" in messages[0]
|
||||||
|
|
||||||
|
# Mock response
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"text": "This is a mocked response from Cohere v2 API",
|
||||||
|
"generation_id": "mock-id",
|
||||||
|
"id": "mock-completion",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 20},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_async_post(*args, **kwargs):
|
||||||
|
# For asynchronous HTTP client
|
||||||
|
data = kwargs.get("data", "{}")
|
||||||
|
request_body = json.loads(data)
|
||||||
|
print("Async Request body:", request_body)
|
||||||
|
|
||||||
|
# Verify the messages are formatted correctly for v2
|
||||||
|
messages = request_body.get("messages", [])
|
||||||
|
assert len(messages) > 0
|
||||||
|
assert "role" in messages[0]
|
||||||
|
assert "content" in messages[0]
|
||||||
|
|
||||||
|
# Mock response
|
||||||
|
return MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"text": "This is a mocked response from Cohere v2 API",
|
||||||
|
"generation_id": "mock-id",
|
||||||
|
"id": "mock-completion",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 20},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock both sync and async HTTP clients
|
||||||
|
with patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post", side_effect=mock_sync_post):
|
||||||
|
with patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=AsyncMock, side_effect=mock_async_post):
|
||||||
|
litellm.set_verbose = True
|
||||||
|
messages = [{"role": "user", "content": "Hello from mock test"}]
|
||||||
|
response = completion(
|
||||||
|
model="cohere_chat_v2/command-r",
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
assert response is not None
|
||||||
|
assert "This is a mocked response" in response.choices[0].message.content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_v2_request_body_with_allowed_params():
    """
    Validate that the cohere_chat_v2 request body is built correctly: the model name
    is forwarded as-is and the messages are sent in the v2 role/content format.

    Note: the mocks below only assert on the core request shape; they do not yet
    check allowed_openai_params values such as response_format or reasoning_effort.
    """
    try:
        import httpx

        class MockResponse:
            def __init__(self, status_code, json_data, is_stream=False):
                self.status_code = status_code
                self._json_data = json_data
                self.headers = {}
                self.is_stream = is_stream

                # For streaming responses with citations
                if is_stream:
                    # Create streaming chunks with citations at the end
                    self._iter_content_chunks = [
                        json.dumps({"text": "Emperor"}).encode(),
                        json.dumps({"text": " penguins"}).encode(),
                        json.dumps({"text": " are"}).encode(),
                        json.dumps({"text": " the"}).encode(),
                        json.dumps({"text": " tallest"}).encode(),
                        json.dumps({"text": " and"}).encode(),
                        json.dumps({"text": " they"}).encode(),
                        json.dumps({"text": " live"}).encode(),
                        json.dumps({"text": " in"}).encode(),
                        json.dumps({"text": " Antarctica"}).encode(),
                        json.dumps({"text": "."}).encode(),
                        # Citations in a separate chunk
                        json.dumps(
                            {
                                "citations": [
                                    {
                                        "start": 0,
                                        "end": 30,
                                        "text": "Emperor penguins are the tallest",
                                        "document_ids": ["doc1"],
                                    },
                                    {
                                        "start": 31,
                                        "end": 70,
                                        "text": "they live in Antarctica",
                                        "document_ids": ["doc2"],
                                    },
                                ]
                            }
                        ).encode(),
                        json.dumps({"finish_reason": "COMPLETE"}).encode(),
                    ]

            def json(self):
                return self._json_data

            @property
            def text(self):
                return json.dumps(self._json_data)

            def iter_lines(self):
                if self.is_stream:
                    for chunk in self._iter_content_chunks:
                        yield chunk
                else:
                    yield json.dumps(self._json_data).encode()

            async def aiter_lines(self):
                if self.is_stream:
                    for chunk in self._iter_content_chunks:
                        yield chunk
                else:
                    yield json.dumps(self._json_data).encode()

        def mock_sync_post(*args, **kwargs):
            # For synchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Sync Request body:", request_body)

            # Verify the model is passed correctly
            assert request_body.get("model") == "command-r"

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            # Mock response
            return MockResponse(
                200,
                {
                    "text": "This is a test response",
                    "generation_id": "test-id",
                    "id": "test",
                    "usage": {"input_tokens": 10, "output_tokens": 20},
                },
            )

        async def mock_async_post(*args, **kwargs):
            # For asynchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Async Request body:", request_body)

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            # Mock response
            return MockResponse(
                200,
                {
                    "text": "This is a test response",
                    "generation_id": "test-id",
                    "id": "test",
                    "usage": {"input_tokens": 10, "output_tokens": 20},
                },
            )

        # Mock both sync and async HTTP clients
        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            side_effect=mock_sync_post,
        ):
            with patch(
                "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
                new_callable=AsyncMock,
                side_effect=mock_async_post,
            ):
                litellm.set_verbose = True
                messages = [{"role": "user", "content": "Hello"}]
                response = completion(
                    model="cohere_chat_v2/command-r",
                    messages=messages,
                )
                assert response is not None

    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
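

# Illustrative only (not used by the tests): a minimal sketch of the v2 request body
# the mocks above assert on. Unlike the v1 chat route, which sends a single "message"
# plus "chat_history", the v2 route forwards an OpenAI-style "messages" list of
# role/content pairs. The constant name below is ours, added purely for documentation.
EXAMPLE_COHERE_V2_REQUEST_BODY = {
    "model": "command-r",
    "messages": [{"role": "user", "content": "Hello"}],
}
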
@pytest.mark.asyncio
async def test_chat_completion_cohere_v2_streaming_citations():
    """
    Test specifically for streaming with citations in Cohere v2.
    """
    try:
        class MockResponse:
            def __init__(self, status_code, json_data, is_stream=False):
                self.status_code = status_code
                self._json_data = json_data
                self.headers = {}
                self.is_stream = is_stream

                # For streaming responses with citations
                if is_stream:
                    # Create streaming chunks with citations at the end
                    self._iter_content_chunks = [
                        json.dumps({"text": "Emperor"}).encode(),
                        json.dumps({"text": " penguins"}).encode(),
                        json.dumps({"text": " are"}).encode(),
                        json.dumps({"text": " the"}).encode(),
                        json.dumps({"text": " tallest"}).encode(),
                        json.dumps({"text": " and"}).encode(),
                        json.dumps({"text": " they"}).encode(),
                        json.dumps({"text": " live"}).encode(),
                        json.dumps({"text": " in"}).encode(),
                        json.dumps({"text": " Antarctica"}).encode(),
                        json.dumps({"text": "."}).encode(),
                        # Citations in a separate chunk
                        json.dumps(
                            {
                                "citations": [
                                    {
                                        "start": 0,
                                        "end": 30,
                                        "text": "Emperor penguins are the tallest",
                                        "document_ids": ["doc1"],
                                    },
                                    {
                                        "start": 31,
                                        "end": 70,
                                        "text": "they live in Antarctica",
                                        "document_ids": ["doc2"],
                                    },
                                ]
                            }
                        ).encode(),
                        json.dumps({"finish_reason": "COMPLETE"}).encode(),
                    ]

            def json(self):
                return self._json_data

            @property
            def text(self):
                return json.dumps(self._json_data)

            def iter_lines(self):
                if self.is_stream:
                    for chunk in self._iter_content_chunks:
                        yield chunk
                else:
                    yield json.dumps(self._json_data).encode()

            async def aiter_lines(self):
                if self.is_stream:
                    for chunk in self._iter_content_chunks:
                        yield chunk
                else:
                    yield json.dumps(self._json_data).encode()

        async def mock_async_post(*args, **kwargs):
            # For asynchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Async Request body:", request_body)

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            # Check if documents are included
            documents = request_body.get("documents", [])
            assert len(documents) > 0

            # Verify stream is set to True
            assert request_body.get("stream") == True

            # Return a streaming response with citations
            return MockResponse(
                200,
                {
                    "text": "Emperor penguins are the tallest penguins and they live in Antarctica.",
                    "generation_id": "mock-id",
                    "id": "mock-completion",
                    "usage": {"input_tokens": 10, "output_tokens": 20},
                },
                is_stream=True,
            )

        # Mock the async HTTP client
        with patch(
            "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
            new_callable=AsyncMock,
            side_effect=mock_async_post,
        ):
            litellm.set_verbose = True
            messages = [
                {
                    "role": "user",
                    "content": "Which penguins are the tallest?",
                },
            ]
            response = await litellm.acompletion(
                model="cohere_chat_v2/command-r",
                messages=messages,
                stream=True,
                documents=[
                    {"title": "Tall penguins", "text": "Emperor penguins are the tallest."},
                    {
                        "title": "Penguin habitats",
                        "text": "Emperor penguins only live in Antarctica.",
                    },
                ],
            )

            # Verify we get streaming chunks with citations
            citations_chunk = False
            async for chunk in response:
                print("received chunk", chunk)
                if hasattr(chunk, "citations") or (
                    isinstance(chunk, dict) and "citations" in chunk
                ):
                    citations_chunk = True
                    break
            assert citations_chunk, "No citations chunk was received"

    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
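

# Hedged sketch (not used by the tests): one way a caller might gather citations from
# a streamed cohere_chat_v2 response. It mirrors the loop in the streaming test above;
# the attribute/dict fallback is an assumption about how the "citations" chunk is
# surfaced, matching exactly what that test checks for.
async def _collect_citations_sketch(stream):
    citations = []
    async for chunk in stream:
        chunk_citations = getattr(chunk, "citations", None)
        if chunk_citations is None and isinstance(chunk, dict):
            chunk_citations = chunk.get("citations")
        if chunk_citations:
            citations.extend(chunk_citations)
    return citations
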
@pytest.mark.skip(reason="Only run this test when you want to test with a real API key")
@pytest.mark.asyncio
async def test_cohere_v2_real_api_call():
    """
    Test for making a real API call to Cohere v2. This test is skipped by default.
    To run this test, remove the skip mark and ensure you have a valid Cohere API key.
    """
    try:
        # Read the API key from the environment; never hard-code real credentials in tests
        assert os.environ.get("CO_API_KEY"), "CO_API_KEY must be set to run this test"

        litellm.set_verbose = True
        messages = [
            {
                "role": "user",
                "content": "What is the capital of France?",
            },
        ]

        # Make a real API call
        response = await litellm.acompletion(
            model="cohere_chat_v2/command-r",
            messages=messages,
            max_tokens=100,
        )

        print("Real API Response:", response)
        assert response is not None
        assert response.choices[0].message.content is not None
        assert len(response.choices[0].message.content) > 0

        # Test streaming with the real API
        stream_response = await litellm.acompletion(
            model="cohere_chat_v2/command-r",
            messages=messages,
            stream=True,
            max_tokens=100,
        )

        # Verify we get streaming chunks
        chunk_count = 0
        async for chunk in stream_response:
            print(f"Stream chunk: {chunk}")
            chunk_count += 1
            if chunk_count > 5:  # Just check a few chunks to avoid a long test
                break

        assert chunk_count > 0, "No streaming chunks were received"

    except Exception as e:
        pytest.fail(f"Error occurred with real API call: {e}")