Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

Merge 565329c4ac into b82af5b826
This commit is contained in: commit af2d5a11b2

11 changed files with 2513 additions and 6 deletions
@@ -383,6 +383,7 @@ open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []
cohere_models: List = []
cohere_chat_models: List = []
+cohere_v2_models: List = []
mistral_chat_models: List = []
text_completion_codestral_models: List = []
anthropic_models: List = []
@@ -480,6 +481,8 @@ def add_known_models():
            cohere_models.append(key)
        elif value.get("litellm_provider") == "cohere_chat":
            cohere_chat_models.append(key)
+        elif value.get("litellm_provider") == "cohere_v2":
+            cohere_v2_models.append(key)
        elif value.get("litellm_provider") == "mistral":
            mistral_chat_models.append(key)
        elif value.get("litellm_provider") == "anthropic":
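For context, add_known_models() walks litellm.model_cost and buckets each entry by its litellm_provider value, so a "cohere_v2" entry lands in the new list. A minimal sketch of that routing follows; the model name and entry below are hypothetical and only illustrate the shape of a model_cost record, not real registry data.

import litellm

# Hypothetical model_cost entry; real entries come from litellm's bundled model map.
example_entry = {
    "cohere_v2/example-command-model": {
        "litellm_provider": "cohere_v2",
        "mode": "chat",
    }
}

cohere_v2_models: list = []
for key, value in example_entry.items():
    if value.get("litellm_provider") == "cohere_v2":
        cohere_v2_models.append(key)

print(cohere_v2_models)  # ["cohere_v2/example-command-model"]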
@@ -623,6 +626,7 @@ model_list = (
    + open_ai_text_completion_models
    + cohere_models
    + cohere_chat_models
+    + cohere_v2_models
    + anthropic_models
    + replicate_models
    + openrouter_models
@@ -674,8 +678,9 @@ provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
models_by_provider: dict = {
    "openai": open_ai_chat_completion_models + open_ai_text_completion_models,
    "text-completion-openai": open_ai_text_completion_models,
-    "cohere": cohere_models + cohere_chat_models,
+    "cohere": cohere_models + cohere_chat_models + cohere_v2_models,
    "cohere_chat": cohere_chat_models,
+    "cohere_v2": cohere_v2_models,
    "anthropic": anthropic_models,
    "replicate": replicate_models,
    "huggingface": huggingface_models,
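With the registry updated, callers can inspect the v2 bucket alongside the existing cohere entries. A small sketch, assuming the package imports cleanly and the model lists have been populated from the bundled model map:

import litellm

# "cohere" now also folds in the v2 models; "cohere_v2" lists them on their own.
v2_models = litellm.models_by_provider.get("cohere_v2", [])
combined = litellm.models_by_provider.get("cohere", [])

# Every v2 model should also appear in the combined cohere bucket.
assert all(m in combined for m in v2_models)
print(len(v2_models), "cohere_v2 models registered")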
@@ -940,6 +945,7 @@ from .llms.bedrock.embed.amazon_titan_v2_transformation import (
    AmazonTitanV2Config,
)
from .llms.cohere.chat.transformation import CohereChatConfig
+from .llms.cohere.chat.transformation_v2 import CohereChatConfigV2
from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
from .llms.openai.image_variations.transformation import OpenAIImageVariationConfig
@@ -105,6 +105,7 @@ LITELLM_CHAT_PROVIDERS = [
    "text-completion-openai",
    "cohere",
    "cohere_chat",
+    "cohere_v2",
    "clarifai",
    "anthropic",
    "anthropic_text",
@@ -23,14 +23,16 @@ def _is_non_openai_azure_model(model: str) -> bool:


def handle_cohere_chat_model_custom_llm_provider(
-    model: str, custom_llm_provider: Optional[str] = None
+    model: str, custom_llm_provider: Optional[str] = None, api_version: Optional[str] = None
) -> Tuple[str, Optional[str]]:
    """
    if user sets model = "cohere/command-r" -> use custom_llm_provider = "cohere_chat"
+    if api_version = "v2" -> use custom_llm_provider = "cohere_v2"

    Args:
-        model:
-        custom_llm_provider:
+        model: The model name
+        custom_llm_provider: The custom LLM provider if specified
+        api_version: The API version (v1 or v2)

    Returns:
        model, custom_llm_provider

@@ -38,6 +40,9 @@ def handle_cohere_chat_model_custom_llm_provider(

    if custom_llm_provider:
        if custom_llm_provider == "cohere" and model in litellm.cohere_chat_models:
+            # Check if v2 API version is specified
+            if api_version == "v2":
+                return model, "cohere_v2"
            return model, "cohere_chat"

    if "/" in model:

@@ -47,6 +52,9 @@ def handle_cohere_chat_model_custom_llm_provider(
            and _custom_llm_provider == "cohere"
            and _model in litellm.cohere_chat_models
        ):
+            # Check if v2 API version is specified
+            if api_version == "v2":
+                return _model, "cohere_v2"
            return _model, "cohere_chat"

    return model, custom_llm_provider

@@ -122,8 +130,23 @@ def get_llm_provider(  # noqa: PLR0915
        return model, custom_llm_provider, dynamic_api_key, api_base

    ### Handle cases when custom_llm_provider is set to cohere/command-r-plus but it should use cohere_chat route
+    # Extract api_version from optional_params if it exists
+    api_version = None
+    if litellm_params and hasattr(litellm_params, "optional_params") and litellm_params.optional_params:
+        api_version = litellm_params.optional_params.get("api_version")
+
+    # Handle direct cohere_v2 model format
+    if model.startswith("cohere_v2/"):
+        model = model.replace("cohere_v2/", "")
+        custom_llm_provider = "cohere_v2"
+
+    # For backward compatibility
+    elif model.startswith("cohere_v2/"):
+        model = model.replace("cohere_v2/", "")
+        custom_llm_provider = "cohere_v2"
+
    model, custom_llm_provider = handle_cohere_chat_model_custom_llm_provider(
-        model, custom_llm_provider
+        model, custom_llm_provider, api_version
    )

    model, custom_llm_provider = handle_anthropic_text_model_custom_llm_provider(
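A short sketch of how the routing helper behaves with the new api_version argument. The import path is assumed from where get_llm_provider currently lives in litellm, and the example relies on "command-r" being present in litellm.cohere_chat_models (true for the bundled model map, but an assumption here):

from litellm.litellm_core_utils.get_llm_provider_logic import (
    handle_cohere_chat_model_custom_llm_provider,
)

# "cohere/command-r" with no api_version keeps the existing v1 chat route.
model, provider = handle_cohere_chat_model_custom_llm_provider("cohere/command-r")
print(model, provider)  # expected: command-r cohere_chat

# The same model with api_version="v2" is routed to the new provider.
model, provider = handle_cohere_chat_model_custom_llm_provider(
    "cohere/command-r", api_version="v2"
)
print(model, provider)  # expected: command-r cohere_v2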
@@ -2005,6 +2005,57 @@ def cohere_messages_pt_v2(  # noqa: PLR0915
    return returned_message, new_messages


+def cohere_messages_pt_v3(messages: List, model: str, llm_provider: str):
+    """
+    Format messages for Cohere v2 API
+
+    In v2, messages are combined in a single array with the following format:
+    [
+        {"role": "USER", "content": "Hello"},
+        {"role": "ASSISTANT", "content": "Hi there!"},
+        {"role": "USER", "content": "How are you?"}
+    ]
+
+    Returns:
+        List of formatted messages in Cohere v2 format
+    """
+    cohere_messages = []
+
+    for msg_i, message in enumerate(messages):
+        role = message["role"].upper()
+
+        # Map OpenAI roles to Cohere v2 roles
+        if role == "USER":
+            pass  # Keep as USER
+        elif role == "ASSISTANT":
+            role = "CHATBOT"  # Cohere v2 uses CHATBOT instead of ASSISTANT
+        elif role == "SYSTEM":
+            role = "USER"  # System messages are sent as USER with a special prefix
+            message["content"] = f"<admin>{message['content']}</admin>"
+        elif role == "TOOL":
+            # Skip tool messages as they'll be handled separately with tool_results
+            continue
+        elif role == "FUNCTION":
+            # Skip function messages as they'll be handled separately with tool_results
+            continue
+
+        # Handle content
+        content = ""
+        if isinstance(message.get("content"), str):
+            content = message["content"]
+        elif isinstance(message.get("content"), list):
+            # Handle content list (text and images)
+            for item in message["content"]:
+                if isinstance(item, dict):
+                    if item.get("type") == "text":
+                        content += item.get("text", "")
+
+        # Add message to the list
+        cohere_messages.append({"role": role, "content": content})
+
+    return cohere_messages
+
+
def cohere_message_pt(messages: list):
    tool_calls: List = get_all_tool_calls(messages=messages)
    prompt = ""
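A worked example of the role mapping this helper performs. The input is ordinary OpenAI-style messages; the expected output shows the Cohere v2 roles, with the system message folded into a USER turn behind an <admin> prefix and the assistant turn mapped to CHATBOT:

from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v3

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
    {"role": "user", "content": [{"type": "text", "text": "How are you?"}]},
]

cohere_messages = cohere_messages_pt_v3(
    messages=messages, model="command-r", llm_provider="cohere_chat"
)
# Expected shape:
# [{"role": "USER", "content": "<admin>You are terse.</admin>"},
#  {"role": "USER", "content": "Hello"},
#  {"role": "CHATBOT", "content": "Hi there!"},
#  {"role": "USER", "content": "How are you?"}]
print(cohere_messages)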
litellm/llms/cohere/chat/transformation_v2.py (new file, 375 lines)
@@ -0,0 +1,375 @@
"""Cohere Chat V2 API Integration Module.

This module provides the necessary classes and functions to interact with Cohere's V2 Chat API.
It handles the transformation of requests and responses between LiteLLM's standard format and
Cohere's specific API requirements.
"""

import json
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v3
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage

# Use absolute imports instead of relative imports
from litellm.llms.cohere.common_utils import ModelResponseIterator as CohereModelResponseIterator
from litellm.llms.cohere.common_utils import validate_environment as cohere_validate_environment

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class CohereErrorV2(BaseLLMException):
    """
    Exception class for Cohere v2 API errors.

    This class handles errors returned by the Cohere v2 API and formats them
    in a way that is consistent with the LiteLLM error handling system.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[httpx.Headers] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(method="POST", url="https://api.cohere.com/v2/chat")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers,
        )


class CohereChatConfigV2(BaseConfig):
    """
    Configuration class for Cohere's V2 API interface.

    Args:
        preamble (str, optional): When specified, the default Cohere preamble will be replaced
            with the provided one.
        generation_id (str, optional): Unique identifier for the generated reply.
        conversation_id (str, optional): Creates or resumes a persisted conversation.
        prompt_truncation (str, optional): Dictates how the prompt will be constructed.
            Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
        connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search)
            to enrich the model's reply.
        search_queries_only (bool, optional): When true, the response will only contain a list
            of generated search queries.
        documents (List[Dict[str, str]] or List[str], optional): A list of relevant documents
            that the model can cite.
        temperature (float, optional): A non-negative float that tunes the degree of randomness
            in generation.
        max_tokens (int, optional): The maximum number of tokens the model will generate as part
            of the response.
        k (int, optional): Ensures only the top k most likely tokens are considered for generation
            at each step.
        p (float, optional): Ensures that only the most likely tokens, with total probability mass
            of p, are considered for generation.
        frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
        presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
        tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model
            may suggest invoking.
        tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
        seed (int, optional): A seed to assist reproducibility of the model's response.
    """

    preamble: Optional[str] = None
    generation_id: Optional[str] = None
    conversation_id: Optional[str] = None
    prompt_truncation: Optional[str] = None
    connectors: Optional[list] = None
    search_queries_only: Optional[bool] = None
    documents: Optional[list] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    k: Optional[int] = None
    p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    tools: Optional[list] = None
    tool_results: Optional[list] = None
    seed: Optional[int] = None

    def __init__(self, **kwargs) -> None:
        """
        Initialize the CohereChatConfigV2 with parameters matching the Cohere v2 API specification.

        All parameters are passed as keyword arguments and set as class attributes
        if they have a non-None value. This approach allows for future API changes
        without requiring code modifications.

        Args:
            **kwargs: Arbitrary keyword arguments matching Cohere v2 API parameters.
                See the class docstring for details on supported parameters.
        """
        # Process all keyword arguments and set as class attributes if not None
        for key, value in kwargs.items():
            if value is not None:
                setattr(self.__class__, key, value)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        # Use the api_key parameter directly.
        # litellm_params is required by the base class but not used by cohere_validate_environment.
        return cohere_validate_environment(
            headers=headers,
            model=model,
            messages=messages,
            optional_params=optional_params,
            api_key=api_key,
            api_version="v2",  # Specify v2 API version
        )

    def get_supported_openai_params(self, model: str) -> List[str]:
        return [
            "stream",
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "stop",
            "n",
            "tools",
            "tool_choice",
            "seed",
            "extra_headers",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for param, value in non_default_params.items():
            if param == "stream":
                optional_params["stream"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "n":
                optional_params["num_generations"] = value
            if param == "top_p":
                optional_params["p"] = value
            if param == "frequency_penalty":
                optional_params["frequency_penalty"] = value
            if param == "presence_penalty":
                optional_params["presence_penalty"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value
            if param == "tools":
                cohere_tools = self._construct_cohere_tool(tools=value)
                optional_params["tools"] = cohere_tools
            if param == "seed":
                optional_params["seed"] = value
        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Use the explicit parameters passed to the method;
        # these variables are used by the parent class implementation.
        ## Load Config
        for k, v in litellm.CohereChatConfigV2.get_config().items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > cohere_config(top_k=3)
                # Allows for dynamic variables to be passed in
                optional_params[k] = v

        # In v2, messages are combined in a single array
        cohere_messages = cohere_messages_pt_v3(
            messages=messages, model=model, llm_provider="cohere_chat"
        )
        optional_params["messages"] = cohere_messages
        optional_params["model"] = model

        ## Tool calling is now handled in map_openai_params

        # Handle tool results if present
        if "tool_results" in optional_params and isinstance(optional_params["tool_results"], list):
            # Convert tool results to v2 format if needed
            tool_results = []
            for result in optional_params["tool_results"]:
                if isinstance(result, dict) and "content" in result:
                    # Format from v1 to v2
                    tool_result = {
                        "tool_call_id": result.get("tool_call_id", ""),
                        "output": result.get("content", ""),
                    }
                    tool_results.append(tool_result)
                else:
                    # Already in v2 format
                    tool_results.append(result)
            optional_params["tool_results"] = tool_results

        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            raw_response_json = raw_response.json()
            # Set the text content from the response,
            # handling both regular and streaming choices
            if hasattr(model_response.choices[0], "message"):
                model_response.choices[0].message.content = raw_response_json.get("text", "")
            else:
                # For streaming responses
                model_response.choices[0].delta.content = raw_response_json.get("text", "")
        except Exception as exc:
            raise CohereErrorV2(
                message=raw_response.text, status_code=raw_response.status_code
            ) from exc

        ## ADD CITATIONS
        # Add citation information to the model response if available
        if "citations" in raw_response_json:
            citations = raw_response_json["citations"]
            setattr(model_response, "citations", citations)

        ## Tool calling response
        cohere_tools_response = raw_response_json.get("tool_calls", None)
        if cohere_tools_response is not None and cohere_tools_response != []:
            # convert cohere_tools_response to OpenAI response format
            tool_calls = []
            for tool in cohere_tools_response:
                function_name = tool.get("name", "")
                tool_call_id = tool.get("id", "")
                parameters = tool.get("parameters", {})
                tool_call = {
                    "id": tool_call_id,
                    "type": "function",
                    "function": {
                        "name": function_name,
                        "arguments": json.dumps(parameters),
                    },
                }
                tool_calls.append(tool_call)
            _message = litellm.Message(
                tool_calls=tool_calls,
                content=None,
            )
            model_response.choices[0].message = _message  # type: ignore

        ## CALCULATING USAGE - use cohere `billed_units` for returning usage
        billed_units = raw_response_json.get("usage", {})

        prompt_tokens = billed_units.get("input_tokens", 0)
        completion_tokens = billed_units.get("output_tokens", 0)

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

    def _construct_cohere_tool(
        self,
        tools: Optional[list] = None,
    ):
        if tools is None:
            tools = []
        cohere_tools = []
        for tool in tools:
            cohere_tool = self._translate_openai_tool_to_cohere(tool)
            cohere_tools.append(cohere_tool)
        return cohere_tools

    def _translate_openai_tool_to_cohere(
        self,
        openai_tool: dict,
    ):
        """
        Translates OpenAI tool format to Cohere v2 tool format

        Cohere v2 tools look like this:
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "input_schema": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["location"]
            }
        }
        """
        cohere_tool = {
            "name": openai_tool["function"]["name"],
            "description": openai_tool["function"]["description"],
            "input_schema": openai_tool["function"]["parameters"],
        }

        return cohere_tool

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        return CohereModelResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return CohereErrorV2(status_code=status_code, message=error_message)
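To illustrate the parameter and tool mapping this config performs, a small sketch that feeds OpenAI-style arguments through map_openai_params. It assumes CohereChatConfigV2 instantiates cleanly outside of a completion call (its __init__ takes only keyword arguments), and it leans on a private helper path (_construct_cohere_tool), so treat it as illustration rather than a supported API:

from litellm.llms.cohere.chat.transformation_v2 import CohereChatConfigV2

config = CohereChatConfigV2()

openai_tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}

optional_params = config.map_openai_params(
    non_default_params={"temperature": 0.2, "top_p": 0.9, "tools": [openai_tool]},
    optional_params={},
    model="command-r",
    drop_params=False,
)
# temperature passes through, top_p becomes "p",
# and tools are rewritten to the Cohere shape with an "input_schema" key.
print(optional_params["p"], optional_params["tools"][0]["name"])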
@@ -21,11 +21,15 @@ def validate_environment(
    messages: List[AllMessageValues],
    optional_params: dict,
    api_key: Optional[str] = None,
+    api_version: Optional[str] = "v1",
) -> dict:
    """
    Return headers to use for cohere chat completion request

-    Cohere API Ref: https://docs.cohere.com/reference/chat
+    Cohere API Ref:
+    - v1: https://docs.cohere.com/reference/chat
+    - v2: https://docs.cohere.com/v2/reference/chat

    Expected headers:
    {
        "Request-Source": "unspecified:litellm",
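A sketch of calling the shared header builder with the new flag. The keyword arguments match the call made from CohereChatConfigV2.validate_environment above; the exact header set comes from the existing helper, so the output noted in the comments is indicative (it includes at least the "Request-Source" header shown in the docstring), and the api_key value is a placeholder:

from litellm.llms.cohere.common_utils import validate_environment

headers = validate_environment(
    headers={},
    model="command-r",
    messages=[{"role": "user", "content": "Hello"}],
    optional_params={},
    api_key="sk-placeholder",  # placeholder, not a real key
    api_version="v2",
)
# Indicative output: a dict containing "Request-Source": "unspecified:litellm"
# plus the JSON content headers and the Authorization header built from api_key.
print(headers)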
@@ -2143,6 +2143,46 @@ def completion(  # type: ignore # noqa: PLR0915
            api_key=cohere_key,
            logging_obj=logging,  # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
        )
+    elif custom_llm_provider == "cohere_v2":
+        cohere_key = (
+            api_key
+            or litellm.cohere_key
+            or get_secret_str("COHERE_API_KEY")
+            or get_secret_str("CO_API_KEY")
+            or litellm.api_key
+        )
+
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret_str("COHERE_API_BASE")
+            or "https://api.cohere.ai/v2/chat"
+        )
+
+        headers = headers or litellm.headers or {}
+        if headers is None:
+            headers = {}
+
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
+        response = base_llm_http_handler.completion(
+            model=model,
+            stream=stream,
+            messages=messages,
+            acompletion=acompletion,
+            api_base=api_base,
+            model_response=model_response,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            custom_llm_provider="cohere_v2",
+            timeout=timeout,
+            headers=headers,
+            encoding=encoding,
+            api_key=cohere_key,
+            logging_obj=logging,  # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
+            client=client,
+        )
    elif custom_llm_provider == "maritalk":
        maritalk_key = (
            api_key
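End to end, the new branch means a completion call can be pointed at the v2 chat endpoint with the provider prefix handled in get_llm_provider. A hedged sketch; it needs a real COHERE_API_KEY (or CO_API_KEY) in the environment, and the model name is assumed to be enabled on the account:

import os

import litellm

# Assumes a Cohere key is exported in the environment.
assert os.environ.get("COHERE_API_KEY") or os.environ.get("CO_API_KEY")

response = litellm.completion(
    model="cohere_v2/command-r",  # the "cohere_v2/" prefix routes to the branch above
    messages=[{"role": "user", "content": "Say hello in one word."}],
    max_tokens=10,
)
print(response.choices[0].message.content)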
@@ -2040,6 +2040,7 @@ class LlmProviders(str, Enum):
    TEXT_COMPLETION_OPENAI = "text-completion-openai"
    COHERE = "cohere"
    COHERE_CHAT = "cohere_chat"
+    COHERE_V2 = "cohere_v2"
    CLARIFAI = "clarifai"
    ANTHROPIC = "anthropic"
    ANTHROPIC_TEXT = "anthropic_text"
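Because LlmProviders is a str-valued enum, the new member can be resolved from the plain provider string; a one-line sketch using the enum as it is re-exported from the litellm package:

import litellm

provider = litellm.LlmProviders("cohere_v2")
print(provider is litellm.LlmProviders.COHERE_V2, provider.value)  # True cohere_v2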
@@ -6409,6 +6409,8 @@ class ProviderConfigManager:
            return litellm.OpenAITextCompletionConfig()
        elif litellm.LlmProviders.COHERE_CHAT == provider:
            return litellm.CohereChatConfig()
+        elif litellm.LlmProviders.COHERE_V2 == provider:
+            return litellm.CohereChatConfigV2()
        elif litellm.LlmProviders.COHERE == provider:
            return litellm.CohereConfig()
        elif litellm.LlmProviders.SNOWFLAKE == provider:
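The dispatch can be exercised directly. The method name get_provider_chat_config and the litellm.utils import path are assumed here, since the hunk only shows the elif branches inside the manager:

import litellm
from litellm.utils import ProviderConfigManager

# Method name assumed; the diff above only shows the branch bodies inside it.
config = ProviderConfigManager.get_provider_chat_config(
    model="command-r", provider=litellm.LlmProviders.COHERE_V2
)
print(type(config).__name__)  # expected: CohereChatConfigV2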
tests/litellm/llms/cohere/chat/test_transformation_v2.py (new file, 1005 lines)
File diff suppressed because it is too large.

tests/llm_translation/test_cohere_v2.py (new file, 999 lines)
@@ -0,0 +1,999 @@
import io
import json
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()

# For testing, make sure the COHERE_API_KEY or CO_API_KEY environment variable is set.
# You can set it before running the tests with: export COHERE_API_KEY=your_api_key

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

from unittest.mock import AsyncMock, patch

import pytest

import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

litellm.num_retries = 3

# Streaming chunks with citations at the end, shared by the mocked streaming tests.
CITATION_STREAM_CHUNKS = [
    json.dumps({"text": "Emperor"}).encode(),
    json.dumps({"text": " penguins"}).encode(),
    json.dumps({"text": " are"}).encode(),
    json.dumps({"text": " the"}).encode(),
    json.dumps({"text": " tallest"}).encode(),
    json.dumps({"text": " and"}).encode(),
    json.dumps({"text": " they"}).encode(),
    json.dumps({"text": " live"}).encode(),
    json.dumps({"text": " in"}).encode(),
    json.dumps({"text": " Antarctica"}).encode(),
    json.dumps({"text": "."}).encode(),
    # Citations in a separate chunk
    json.dumps(
        {
            "citations": [
                {
                    "start": 0,
                    "end": 30,
                    "text": "Emperor penguins are the tallest",
                    "document_ids": ["doc1"],
                },
                {
                    "start": 31,
                    "end": 70,
                    "text": "they live in Antarctica",
                    "document_ids": ["doc2"],
                },
            ]
        }
    ).encode(),
    json.dumps({"finish_reason": "COMPLETE"}).encode(),
]

MOCK_CITATIONS = [
    {
        "start": 0,
        "end": 30,
        "text": "Emperor penguins are the tallest",
        "document_ids": ["doc1"],
    },
    {
        "start": 31,
        "end": 70,
        "text": "they live in Antarctica",
        "document_ids": ["doc2"],
    },
]


class MockResponse:
    """Minimal stand-in for an httpx response used by the mocked HTTP handlers."""

    def __init__(self, status_code, json_data, is_stream=False, stream_chunks=None):
        self.status_code = status_code
        self._json_data = json_data
        self.headers = {}
        self.is_stream = is_stream
        # Streamed bodies are yielded line by line; default to the citation chunks.
        self._iter_content_chunks = stream_chunks or CITATION_STREAM_CHUNKS

    def json(self):
        return self._json_data

    @property
    def text(self):
        return json.dumps(self._json_data)

    def iter_lines(self):
        if self.is_stream:
            for chunk in self._iter_content_chunks:
                yield chunk
        else:
            yield json.dumps(self._json_data).encode()

    async def aiter_lines(self):
        if self.is_stream:
            for chunk in self._iter_content_chunks:
                yield chunk
        else:
            yield json.dumps(self._json_data).encode()


@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.flaky(retries=3, delay=1)
@pytest.mark.asyncio
async def test_chat_completion_cohere_v2_citations(stream):
    try:
        async def mock_async_post(*args, **kwargs):
            # For the asynchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Async Request body:", request_body)

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            # Check that documents are included
            documents = request_body.get("documents", [])
            assert len(documents) > 0

            # Mock response with citations
            mock_response = {
                "text": "Emperor penguins are the tallest penguins and they live in Antarctica.",
                "generation_id": "mock-id",
                "id": "mock-completion",
                "usage": {"input_tokens": 10, "output_tokens": 20},
                "citations": MOCK_CITATIONS,
            }

            if stream:
                # Streaming response with the citations delivered in a trailing chunk
                return MockResponse(200, {**mock_response, "stream": True}, is_stream=True)
            return MockResponse(200, mock_response)

        # Mock the async HTTP client
        with patch(
            "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
            new_callable=AsyncMock,
            side_effect=mock_async_post,
        ):
            litellm.set_verbose = True
            messages = [
                {
                    "role": "user",
                    "content": "Which penguins are the tallest?",
                },
            ]
            response = await litellm.acompletion(
                model="cohere_chat_v2/command-r",
                messages=messages,
                stream=stream,
                documents=[
                    {"title": "Tall penguins", "text": "Emperor penguins are the tallest."},
                    {
                        "title": "Penguin habitats",
                        "text": "Emperor penguins only live in Antarctica.",
                    },
                ],
            )

            if stream:
                citations_chunk = False
                async for chunk in response:
                    print("received chunk", chunk)
                    if hasattr(chunk, "citations") or (
                        isinstance(chunk, dict) and "citations" in chunk
                    ):
                        citations_chunk = True
                        break
                assert citations_chunk
            else:
                assert hasattr(response, "citations")
    except litellm.ServiceUnavailableError:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_completion_cohere_v2_command_r_plus_function_call():
    litellm.set_verbose = True
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    messages = [
        {
            "role": "user",
            "content": "What's the weather like in Boston today in Fahrenheit?",
        }
    ]
    try:
        # test without max tokens
        response = completion(
            model="command-r-plus",
            messages=messages,
            tools=tools,
            tool_choice="auto",
            api_version="v2",  # Specify v2 API version
        )
        # Add any assertions here to check response args
        print(response)
        assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
        assert isinstance(
            response.choices[0].message.tool_calls[0].function.arguments, str
        )

        messages.append(
            response.choices[0].message.model_dump()
        )  # Add assistant tool invokes
        tool_result = (
            '{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
        )
        # Add user-submitted tool results in the OpenAI format
        messages.append(
            {
                "tool_call_id": response.choices[0].message.tool_calls[0].id,
                "role": "tool",
                "name": response.choices[0].message.tool_calls[0].function.name,
                "content": tool_result,
            }
        )
        # In the second response, Cohere should deduce the answer from the tool results
        second_response = completion(
            model="command-r-plus",
            messages=messages,
            tools=tools,
            tool_choice="auto",
            force_single_step=True,
            api_version="v2",  # Specify v2 API version
        )
        print(second_response)
    except litellm.Timeout:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.flaky(retries=6, delay=1)
def test_completion_cohere_v2():
    try:
        # litellm.set_verbose = True
        messages = [
            {"role": "system", "content": "You're a good bot"},
            {
                "role": "user",
                "content": "Hey",
            },
        ]
        response = completion(
            model="command-r",
            messages=messages,
            api_version="v2",  # Specify v2 API version
        )
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_chat_completion_cohere_v2(sync_mode):
    try:
        def mock_sync_post(*args, **kwargs):
            # For the synchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Sync Request body:", request_body)

            # Verify the model and max_tokens are passed correctly
            assert request_body.get("model") == "command-r"
            assert request_body.get("max_tokens") == 10

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            return MockResponse(
                200,
                {
                    "text": "This is a mocked response for sync request",
                    "generation_id": "mock-id",
                    "id": "mock-completion",
                    "usage": {"input_tokens": 10, "output_tokens": 20},
                },
            )

        async def mock_async_post(*args, **kwargs):
            # For the asynchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Async Request body:", request_body)

            # Verify the model and max_tokens are passed correctly
            assert request_body.get("model") == "command-r"
            assert request_body.get("max_tokens") == 10

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            return MockResponse(
                200,
                {
                    "text": "This is a mocked response for async request",
                    "generation_id": "mock-id",
                    "id": "mock-completion",
                    "usage": {"input_tokens": 10, "output_tokens": 20},
                },
            )

        # Mock both sync and async HTTP clients
        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            side_effect=mock_sync_post,
        ), patch(
            "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
            new_callable=AsyncMock,
            side_effect=mock_async_post,
        ):
            litellm.set_verbose = True
            messages = [
                {"role": "system", "content": "You're a good bot"},
                {
                    "role": "user",
                    "content": "Hey",
                },
            ]
            if sync_mode is False:
                response = await litellm.acompletion(
                    model="cohere_chat_v2/command-r",
                    messages=messages,
                    max_tokens=10,
                )
            else:
                response = completion(
                    model="cohere_chat_v2/command-r",
                    messages=messages,
                    max_tokens=10,
                )
            print(response)
            assert response is not None
            assert "This is a mocked response" in response.choices[0].message.content
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [False])
async def test_chat_completion_cohere_v2_stream(sync_mode):
    try:
        stream_chunks = [
            json.dumps({"text": "This"}).encode(),
            json.dumps({"text": " is"}).encode(),
            json.dumps({"text": " a"}).encode(),
            json.dumps({"text": " streamed"}).encode(),
            json.dumps({"text": " response"}).encode(),
            json.dumps({"text": "."}).encode(),
            json.dumps({"finish_reason": "COMPLETE"}).encode(),
        ]

        async def mock_async_post(*args, **kwargs):
            # For the asynchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Async Request body:", request_body)

            # Verify the model, max_tokens and stream flag are passed correctly
            assert request_body.get("model") == "command-r"
            assert request_body.get("max_tokens") == 10
            assert request_body.get("stream") is True

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            # Return a streaming response
            return MockResponse(
                200,
                {
                    "text": "This is a streamed response.",
                    "generation_id": "mock-id",
                    "id": "mock-completion",
                    "usage": {"input_tokens": 10, "output_tokens": 20},
                },
                is_stream=True,
                stream_chunks=stream_chunks,
            )

        # Mock the async HTTP client for streaming
        with patch(
            "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
            new_callable=AsyncMock,
            side_effect=mock_async_post,
        ):
            litellm.set_verbose = True
            messages = [
                {"role": "system", "content": "You're a good bot"},
                {
                    "role": "user",
                    "content": "Hey",
                },
            ]
            if sync_mode is False:
                response = await litellm.acompletion(
                    model="cohere_chat_v2/command-r",
                    messages=messages,
                    stream=True,
                    max_tokens=10,
                )
                # Verify we get streaming chunks
                chunk_count = 0
                async for chunk in response:
                    print(f"chunk: {chunk}")
                    chunk_count += 1
                assert chunk_count > 0, "No streaming chunks were received"
            else:
                # This test is only for async mode
                pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_cohere_v2_mock_completion():
    """
    Test cohere_chat_v2 completion with mocked responses to avoid API calls
    """
    try:
        def mock_sync_post(*args, **kwargs):
            # For the synchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Sync Request body:", request_body)

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            return MockResponse(
                200,
                {
                    "text": "This is a mocked response from Cohere v2 API",
                    "generation_id": "mock-id",
                    "id": "mock-completion",
                    "usage": {"input_tokens": 10, "output_tokens": 20},
                },
            )

        async def mock_async_post(*args, **kwargs):
            # For the asynchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Async Request body:", request_body)

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            return MockResponse(
                200,
                {
                    "text": "This is a mocked response from Cohere v2 API",
                    "generation_id": "mock-id",
                    "id": "mock-completion",
                    "usage": {"input_tokens": 10, "output_tokens": 20},
                },
            )

        # Mock both sync and async HTTP clients
        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            side_effect=mock_sync_post,
        ), patch(
            "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
            new_callable=AsyncMock,
            side_effect=mock_async_post,
        ):
            litellm.set_verbose = True
            messages = [{"role": "user", "content": "Hello from mock test"}]
            response = completion(
                model="cohere_chat_v2/command-r",
                messages=messages,
            )
            assert response is not None
            assert "This is a mocked response" in response.choices[0].message.content

    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_cohere_v2_request_body_with_allowed_params():
    """
    Test to validate that when allowed_openai_params is provided, the request body contains
    the correct response_format and reasoning_effort values.
    """
    try:
        def mock_sync_post(*args, **kwargs):
            # For the synchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Sync Request body:", request_body)

            # Verify the model is passed correctly
            assert request_body.get("model") == "command-r"

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            return MockResponse(
                200,
                {
                    "text": "This is a test response",
                    "generation_id": "test-id",
                    "id": "test",
                    "usage": {"input_tokens": 10, "output_tokens": 20},
                },
            )

        async def mock_async_post(*args, **kwargs):
            # For the asynchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Async Request body:", request_body)

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            return MockResponse(
                200,
                {
                    "text": "This is a test response",
                    "generation_id": "test-id",
                    "id": "test",
                    "usage": {"input_tokens": 10, "output_tokens": 20},
                },
            )

        # Mock both sync and async HTTP clients
        with patch(
            "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
            side_effect=mock_sync_post,
        ), patch(
            "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
            new_callable=AsyncMock,
            side_effect=mock_async_post,
        ):
            litellm.set_verbose = True
            messages = [{"role": "user", "content": "Hello"}]
            response = completion(
                model="cohere_chat_v2/command-r",
                messages=messages,
            )
            assert response is not None

    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio
async def test_chat_completion_cohere_v2_streaming_citations():
    """
    Test specifically for streaming with citations in Cohere v2
    """
    try:
        async def mock_async_post(*args, **kwargs):
            # For the asynchronous HTTP client
            data = kwargs.get("data", "{}")
            request_body = json.loads(data)
            print("Async Request body:", request_body)

            # Verify the messages are formatted correctly for v2
            messages = request_body.get("messages", [])
            assert len(messages) > 0
            assert "role" in messages[0]
            assert "content" in messages[0]

            # Check that documents are included
            documents = request_body.get("documents", [])
            assert len(documents) > 0

            # Verify stream is set to True
            assert request_body.get("stream") is True

            # Return a streaming response; the citations arrive via CITATION_STREAM_CHUNKS
            return MockResponse(
                200,
                {
                    "text": "Emperor penguins are the tallest penguins and they live in Antarctica.",
                    "generation_id": "mock-id",
                    "id": "mock-completion",
                    "usage": {"input_tokens": 10, "output_tokens": 20},
                },
                is_stream=True,
            )

        # Mock the async HTTP client
        with patch(
            "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
            new_callable=AsyncMock,
            side_effect=mock_async_post,
        ):
            litellm.set_verbose = True
            messages = [
                {
                    "role": "user",
                    "content": "Which penguins are the tallest?",
                },
            ]
            response = await litellm.acompletion(
                model="cohere_chat_v2/command-r",
                messages=messages,
                stream=True,
                documents=[
                    {"title": "Tall penguins", "text": "Emperor penguins are the tallest."},
                    {
                        "title": "Penguin habitats",
                        "text": "Emperor penguins only live in Antarctica.",
                    },
                ],
            )

            # Verify we get streaming chunks with citations
            citations_chunk = False
            async for chunk in response:
                print("received chunk", chunk)
                if hasattr(chunk, "citations") or (
                    isinstance(chunk, dict) and "citations" in chunk
                ):
                    citations_chunk = True
                    break
            assert citations_chunk, "No citations chunk was received"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.skip(reason="Only run this test when you want to test with a real API key")
@pytest.mark.asyncio
async def test_cohere_v2_real_api_call():
    """
    Test for making a real API call to Cohere v2. This test is skipped by default.
    To run this test, remove the skip mark and ensure you have a valid Cohere API key.
    """
    try:
        # Use the API key from the environment; never hardcode secrets in tests
        assert os.environ.get("CO_API_KEY") or os.environ.get(
            "COHERE_API_KEY"
        ), "Set CO_API_KEY or COHERE_API_KEY to run this test"

        litellm.set_verbose = True
        messages = [
            {
                "role": "user",
                "content": "What is the capital of France?",
            },
        ]

        # Make a real API call
        response = await litellm.acompletion(
            model="cohere_chat_v2/command-r",
            messages=messages,
            max_tokens=100,
        )

        print("Real API Response:", response)
        assert response is not None
        assert response.choices[0].message.content is not None
        assert len(response.choices[0].message.content) > 0

        # Test streaming with the real API
        stream_response = await litellm.acompletion(
            model="cohere_chat_v2/command-r",
            messages=messages,
            stream=True,
            max_tokens=100,
        )

        # Verify we get streaming chunks
        chunk_count = 0
        async for chunk in stream_response:
            print(f"Stream chunk: {chunk}")
            chunk_count += 1
            if chunk_count > 5:  # Just check a few chunks to avoid a long test
                break

        assert chunk_count > 0, "No streaming chunks were received"

    except Exception as e:
        pytest.fail(f"Error occurred with real API call: {e}")
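The mock-based tests above need no network access, so the file can be run on its own; the skipped real-API test stays skipped unless a key is provided. A minimal Python entry point for that (the path and pytest arguments are illustrative):

import sys

import pytest

# Run the mocked cohere_v2 tests quietly; the real-API test remains skipped.
sys.exit(pytest.main(["tests/llm_translation/test_cohere_v2.py", "-q"]))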