(Refactor) Code Quality improvement - use Common base handler for Cohere (#7117)

* fix: use new format for Cohere config

* fix base llm http handler

* Litellm code qa common config (#7116)

* feat(base_llm): initial commit for common base config class

Addresses code qa critique https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132

* feat(base_llm/): add transform request/response abstract methods to base config class

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>

* use base transform helpers

* use base_llm_http_handler for cohere

* working cohere using base llm handler

* add async cohere chat completion support on base handler

* fix completion code

* working sync cohere stream

* add async support cohere_chat

* fix types get_model_response_iterator

* async / sync tests cohere

* feat: cohere using base llm class

* fix linting errors

* fix _abc error

* add cohere params to transformation

* remove old cohere file

* fix type error

* fix merge conflicts

* fix cohere merge conflicts

* fix linting error

* fix litellm.llms.custom_httpx.http_handler.HTTPHandler.post

* fix passing cohere specific params

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
Ishaan Jaff 2024-12-09 17:45:29 -08:00 committed by GitHub
parent 5bbf906c83
commit ff7c95694d
14 changed files with 933 additions and 720 deletions
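
The diffs below implement this refactor file by file. As orientation: a provider that adopts the common base handler supplies only a config object, and the shared BaseLLMHTTPHandler drives the request in a fixed order: validate_environment builds headers, transform_request builds the provider payload, the handler performs the HTTP call, and transform_response maps the raw response back to a ModelResponse. The following is a minimal sketch of that flow using stand-in classes, not litellm's real signatures:

# Illustrative sketch only: stand-in classes mirroring the BaseConfig interface
# introduced below, not litellm's real types or signatures.
from typing import Any, Dict, List, Optional


class StubProviderConfig:
    """Stand-in for a BaseConfig subclass such as CohereChatConfig."""

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[dict],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        # Provider-specific auth and content headers.
        if api_key:
            headers["Authorization"] = f"bearer {api_key}"
        headers["content-type"] = "application/json"
        return headers

    def transform_request(
        self,
        model: str,
        messages: List[dict],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Map OpenAI-style messages into the provider's request body.
        return {"model": model, "message": messages[-1]["content"], **optional_params}

    def transform_response(self, raw_response_json: Dict[str, Any]) -> Dict[str, Any]:
        # Map the provider's raw JSON back to an OpenAI-style response
        # (the real method receives an httpx.Response and a ModelResponse).
        return {"choices": [{"message": {"content": raw_response_json.get("text", "")}}]}


def run_completion(config: StubProviderConfig, model: str, messages: List[dict], api_key: str):
    headers = config.validate_environment(
        headers={}, model=model, messages=messages, optional_params={}, api_key=api_key
    )
    data = config.transform_request(
        model=model, messages=messages, optional_params={}, litellm_params={}, headers=headers
    )
    # A real handler would POST `data` with `headers` here (sync, async, or streaming),
    # then hand the raw response to transform_response.
    fake_provider_json = {"text": "hello"}
    return config.transform_response(fake_provider_json)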

View file

@ -1752,6 +1752,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "text-completion-openai"
or self.custom_llm_provider == "text-completion-codestral"
or self.custom_llm_provider == "azure_text"
or self.custom_llm_provider == "cohere_chat"
or self.custom_llm_provider == "anthropic"
or self.custom_llm_provider == "anthropic_text"
or self.custom_llm_provider == "huggingface"

View file

@ -15,11 +15,11 @@ from litellm.types.utils import ModelResponse
from ..common_utils import OpenAIError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LoggingClass = Any
LiteLLMLoggingObj = Any
class OpenAIGPTConfig(BaseConfig):
@ -189,12 +189,12 @@ class OpenAIGPTConfig(BaseConfig):
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
api_key: str,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
encoding: Any,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
"""
@ -216,10 +216,10 @@ class OpenAIGPTConfig(BaseConfig):
def validate_environment(
self,
api_key: str,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
raise NotImplementedError
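
This TYPE_CHECKING block recurs throughout the diff: the old LoggingClass alias is replaced by a LiteLLMLoggingObj name that resolves to the real Logging class for type checkers and to Any at runtime, which avoids a circular import. In isolation the pattern looks like this (module path as used in the diff):

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Resolved only by type checkers; never imported at runtime, so no import cycle.
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    # At runtime the annotation degrades to Any.
    LiteLLMLoggingObj = Any


def transform_response(logging_obj: LiteLLMLoggingObj) -> None:
    ...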

View file

@ -564,11 +564,11 @@ class AnthropicConfig(BaseConfig):
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
api_key: str,
request_data: Dict,
messages: List[AllMessageValues],
optional_params: Dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
_hidden_params: Dict = {}
@ -721,7 +721,7 @@ class AnthropicConfig(BaseConfig):
return messages
def get_error_class(
self, error_message: str, status_code: int, headers: Dict
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
) -> BaseLLMException:
return AnthropicError(
status_code=status_code,
@ -731,11 +731,11 @@ class AnthropicConfig(BaseConfig):
def validate_environment(
self,
api_key: str,
headers: Dict,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> Dict:
if api_key is None:
raise litellm.AuthenticationError(

View file

@ -4,7 +4,16 @@ Common base config for all LLM providers
import types
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Iterator,
List,
Optional,
Union,
)
import httpx
@ -12,11 +21,11 @@ from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LoggingClass = Any
LiteLLMLoggingObj = Any
class BaseLLMException(Exception):
@ -78,11 +87,11 @@ class BaseConfig(ABC):
@abstractmethod
def validate_environment(
self,
api_key: str,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
pass
@ -109,21 +118,26 @@ class BaseConfig(ABC):
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
api_key: str,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
encoding: Any,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
pass
@abstractmethod
def get_error_class(
self,
error_message: str,
status_code: int,
headers: dict,
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
pass
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str]],
sync_stream: bool,
json_mode: Optional[bool] = False,
) -> Any:
pass
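
The get_model_response_iterator hook added at the end of BaseConfig is what lets the shared handler stream for any provider: the handler passes the raw line iterator from httpx to the config, and the config returns an object CustomStreamWrapper can iterate. Below is a hedged sketch of a minimal sync implementation of that contract; the chunk fields are a simplified stand-in for GenericStreamingChunk, and Cohere's real parser appears later in this diff.

import json
from typing import Iterator, Optional


class MinimalResponseIterator:
    """Stand-in for a provider's ModelResponseIterator (sync side only)."""

    def __init__(
        self,
        streaming_response: Iterator[str],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        self.streaming_response = streaming_response

    def __iter__(self):
        return self

    def __next__(self) -> dict:
        line = next(self.streaming_response)  # StopIteration ends the stream
        chunk = json.loads(line)              # assume each line is a JSON chunk
        # Simplified normalized shape; the real code builds a GenericStreamingChunk.
        return {
            "text": chunk.get("text", ""),
            "is_finished": chunk.get("is_finished", False),
            "finish_reason": chunk.get("finish_reason", ""),
        }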

View file

@ -1,5 +1,5 @@
import types
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
@ -104,11 +104,11 @@ class ClarifaiConfig(BaseConfig):
def validate_environment(
self,
api_key: str,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
headers = {
"accept": "application/json",
@ -125,7 +125,7 @@ class ClarifaiConfig(BaseConfig):
raise NotImplementedError
def get_error_class(
self, error_message: str, status_code: int, headers: dict
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return ClarifaiError(message=error_message, status_code=status_code)
@ -135,11 +135,11 @@ class ClarifaiConfig(BaseConfig):
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
api_key: str,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> litellm.ModelResponse:
logging_obj.post_call(

View file

@ -1,453 +0,0 @@
import json
import os
import time
import traceback
import types
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.cohere import ToolResultObject
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
GenericStreamingChunk,
)
from litellm.utils import Choices, Message, ModelResponse, Usage
from ...prompt_templates.factory import cohere_message_pt, cohere_messages_pt_v2
from .transformation import CohereChatConfig, CohereError
def translate_openai_tool_to_cohere(openai_tool):
# cohere tools look like this
"""
{
"name": "query_daily_sales_report",
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
"parameter_definitions": {
"day": {
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
"type": "str",
"required": True
}
}
}
"""
# OpenAI tools look like this
"""
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
"""
cohere_tool = {
"name": openai_tool["function"]["name"],
"description": openai_tool["function"]["description"],
"parameter_definitions": {},
}
for param_name, param_def in openai_tool["function"]["parameters"][
"properties"
].items():
required_params = (
openai_tool.get("function", {}).get("parameters", {}).get("required", [])
)
cohere_param_def = {
"description": param_def.get("description", ""),
"type": param_def.get("type", ""),
"required": param_name in required_params,
}
cohere_tool["parameter_definitions"][param_name] = cohere_param_def
return cohere_tool
def construct_cohere_tool(tools=None):
if tools is None:
tools = []
cohere_tools = []
for tool in tools:
cohere_tool = translate_openai_tool_to_cohere(tool)
cohere_tools.append(cohere_tool)
return cohere_tools
async def make_call(
client: Optional[AsyncHTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_aclient
try:
response = await client.post(
api_base, headers=headers, data=data, stream=True, timeout=timeout
)
except httpx.HTTPStatusError as e:
error_headers = getattr(e, "headers", None)
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise CohereError(
status_code=e.response.status_code,
message=await e.response.aread(),
)
except Exception as e:
for exception in litellm.LITELLM_EXCEPTION_TYPES:
if isinstance(e, exception):
raise e
raise CohereError(status_code=500, message=str(e))
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(),
sync_stream=False,
json_mode=json_mode,
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_stream, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return completion_stream, response.headers
def make_sync_call(
client: Optional[HTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_client # re-use a module level client
try:
response = client.post(
api_base, headers=headers, data=data, stream=True, timeout=timeout
)
except httpx.HTTPStatusError as e:
raise CohereError(
status_code=e.response.status_code,
message=e.response.read(),
)
except Exception as e:
for exception in litellm.LITELLM_EXCEPTION_TYPES:
if isinstance(e, exception):
raise e
raise CohereError(status_code=500, message=str(e))
if response.status_code != 200:
raise CohereError(
status_code=response.status_code,
message=response.read(),
)
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream, response.headers
def completion( # noqa: PLR0915
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
optional_params: dict,
headers: dict,
encoding,
api_key,
logging_obj,
litellm_params=None,
logger_fn=None,
client=None,
timeout=None,
):
headers = litellm.CohereChatConfig().validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
)
completion_url = api_base
model = model
most_recent_message, chat_history = cohere_messages_pt_v2(
messages=messages, model=model, llm_provider="cohere_chat"
)
## Load Config
config = litellm.CohereConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
optional_params["tools"] = cohere_tools
if isinstance(most_recent_message, dict):
optional_params["tool_results"] = [most_recent_message]
elif isinstance(most_recent_message, str):
optional_params["message"] = most_recent_message
## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
optional_params["force_single_step"] = True
data = {
"model": model,
"chat_history": chat_history,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=most_recent_message,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
)
## COMPLETION CALL
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
## error handling for cohere calls
if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] is True:
completion_stream, cohere_headers = make_sync_call(
client=client,
api_base=api_base,
headers=headers, # type: ignore
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cohere_chat",
logging_obj=logging_obj,
_response_headers=dict(cohere_headers),
)
else:
## LOGGING
logging_obj.post_call(
input=most_recent_message,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
completion_response = response.json()
try:
model_response.choices[0].message.content = completion_response["text"] # type: ignore
except Exception:
raise CohereError(message=response.text, status_code=response.status_code)
## ADD CITATIONS
if "citations" in completion_response:
setattr(model_response, "citations", completion_response["citations"])
## Tool calling response
cohere_tools_response = completion_response.get("tool_calls", None)
if cohere_tools_response is not None and cohere_tools_response != []:
# convert cohere_tools_response to OpenAI response format
tool_calls = []
for tool in cohere_tools_response:
function_name = tool.get("name", "")
generation_id = tool.get("generation_id", "")
parameters = tool.get("parameters", {})
tool_call = {
"id": f"call_{generation_id}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(parameters),
},
}
tool_calls.append(tool_call)
_message = litellm.Message(
tool_calls=tool_calls,
content=None,
)
model_response.choices[0].message = _message # type: ignore
## CALCULATING USAGE - use cohere `billed_units` for returning usage
billed_units = completion_response.get("meta", {}).get("billed_units", {})
prompt_tokens = billed_units.get("input_tokens", 0)
completion_tokens = billed_units.get("output_tokens", 0)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
class ModelResponseIterator:
def __init__(
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
):
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.content_blocks: List = []
self.tool_index = -1
self.json_mode = json_mode
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields = None
index = int(chunk.get("index", 0))
if "text" in chunk:
text = chunk["text"]
elif "is_finished" in chunk and chunk["is_finished"] is True:
is_finished = chunk["is_finished"]
finish_reason = chunk["finish_reason"]
if "citations" in chunk:
provider_specific_fields = {"citations": chunk["citations"]}
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=index,
provider_specific_fields=provider_specific_fields,
)
return returned_chunk
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
data_json = json.loads(str_line)
return self.chunk_parser(chunk=data_json)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
data_json = json.loads(str_line)
return self.chunk_parser(chunk=data_json)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

View file

@ -1,21 +1,45 @@
import types
from typing import TYPE_CHECKING, Any, List, Optional
import json
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.prompt_templates.factory import cohere_messages_pt_v2
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from ..common_utils import CohereError
from ..common_utils import validate_environment as cohere_validate_environment
from litellm.types.utils import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
GenericStreamingChunk,
ModelResponse,
Usage,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LoggingObj = LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LoggingObj = Any
LiteLLMLoggingObj = Any
class CohereError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[httpx.Headers] = None,
):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
status_code=status_code,
message=message,
headers=headers,
)
class CohereChatConfig(BaseConfig):
@ -88,19 +112,36 @@ class CohereChatConfig(BaseConfig):
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
"""
Return headers to use for cohere chat completion request
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
raise NotImplementedError
def get_error_class(
self, error_message: str, status_code: int, headers: dict
) -> BaseLLMException:
return CohereError(status_code=status_code, message=error_message)
Cohere API Ref: https://docs.cohere.com/reference/chat
Expected headers:
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
"Authorization": "bearer $CO_API_KEY"
}
"""
headers.update(
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
)
if api_key:
headers["Authorization"] = f"bearer {api_key}"
return headers
def get_supported_openai_params(self, model: str) -> List[str]:
return [
@ -156,29 +197,293 @@ class CohereChatConfig(BaseConfig):
litellm_params: dict,
headers: dict,
) -> dict:
raise NotImplementedError
## Load Config
for k, v in litellm.CohereChatConfig.get_config().items():
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
most_recent_message, chat_history = cohere_messages_pt_v2(
messages=messages, model=model, llm_provider="cohere_chat"
)
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
optional_params["tools"] = cohere_tools
if isinstance(most_recent_message, dict):
optional_params["tool_results"] = [most_recent_message]
elif isinstance(most_recent_message, str):
optional_params["message"] = most_recent_message
## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
optional_params["force_single_step"] = True
return optional_params
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingObj,
api_key: str,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
encoding: Any,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
raw_response_json = raw_response.json()
model_response.choices[0].message.content = raw_response_json["text"] # type: ignore
except Exception:
raise CohereError(
message=raw_response.text, status_code=raw_response.status_code
)
## ADD CITATIONS
if "citations" in raw_response_json:
setattr(model_response, "citations", raw_response_json["citations"])
## Tool calling response
cohere_tools_response = raw_response_json.get("tool_calls", None)
if cohere_tools_response is not None and cohere_tools_response != []:
# convert cohere_tools_response to OpenAI response format
tool_calls = []
for tool in cohere_tools_response:
function_name = tool.get("name", "")
generation_id = tool.get("generation_id", "")
parameters = tool.get("parameters", {})
tool_call = {
"id": f"call_{generation_id}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(parameters),
},
}
tool_calls.append(tool_call)
_message = litellm.Message(
tool_calls=tool_calls,
content=None,
)
model_response.choices[0].message = _message # type: ignore
## CALCULATING USAGE - use cohere `billed_units` for returning usage
billed_units = raw_response_json.get("meta", {}).get("billed_units", {})
prompt_tokens = billed_units.get("input_tokens", 0)
completion_tokens = billed_units.get("output_tokens", 0)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def _construct_cohere_tool(
self,
tools: Optional[list] = None,
):
if tools is None:
tools = []
cohere_tools = []
for tool in tools:
cohere_tool = self._translate_openai_tool_to_cohere(tool)
cohere_tools.append(cohere_tool)
return cohere_tools
def _translate_openai_tool_to_cohere(
self,
openai_tool: dict,
):
# cohere tools look like this
"""
{
"name": "query_daily_sales_report",
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
"parameter_definitions": {
"day": {
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
"type": "str",
"required": True
}
}
}
"""
# OpenAI tools look like this
"""
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
"""
cohere_tool = {
"name": openai_tool["function"]["name"],
"description": openai_tool["function"]["description"],
"parameter_definitions": {},
}
for param_name, param_def in openai_tool["function"]["parameters"][
"properties"
].items():
required_params = (
openai_tool.get("function", {})
.get("parameters", {})
.get("required", [])
)
cohere_param_def = {
"description": param_def.get("description", ""),
"type": param_def.get("type", ""),
"required": param_name in required_params,
}
cohere_tool["parameter_definitions"][param_name] = cohere_param_def
return cohere_tool
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str]],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return ModelResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(status_code=status_code, message=error_message)
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
raise NotImplementedError
def validate_environment(
self,
api_key: str,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
) -> dict:
return cohere_validate_environment(api_key=api_key, headers=headers)
class ModelResponseIterator:
def __init__(
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
):
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.content_blocks: List = []
self.tool_index = -1
self.json_mode = json_mode
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields = None
index = int(chunk.get("index", 0))
if "text" in chunk:
text = chunk["text"]
elif "is_finished" in chunk and chunk["is_finished"] is True:
is_finished = chunk["is_finished"]
finish_reason = chunk["finish_reason"]
if "citations" in chunk:
provider_specific_fields = {"citations": chunk["citations"]}
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=index,
provider_specific_fields=provider_specific_fields,
)
return returned_chunk
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
data_json = json.loads(str_line)
return self.chunk_parser(chunk=data_json)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
data_json = json.loads(str_line)
return self.chunk_parser(chunk=data_json)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

View file

@ -1,3 +1,5 @@
from typing import Optional
from litellm.llms.base_llm.transformation import BaseLLMException
@ -6,7 +8,7 @@ class CohereError(BaseLLMException):
super().__init__(status_code=status_code, message=message)
def validate_environment(*, api_key: str, headers: dict) -> dict:
def validate_environment(*, api_key: Optional[str], headers: dict) -> dict:
headers.update(
{
"Request-Source": "unspecified:litellm",

View file

@ -1,13 +1,9 @@
import types
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
from litellm.llms.base_llm.transformation import (
BaseConfig,
BaseLLMException,
LoggingClass,
)
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
@ -15,11 +11,11 @@ from ..common_utils import CohereError
from ..common_utils import validate_environment as cohere_validate_environment
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LoggingObj = LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LoggingObj = Any
LiteLLMLoggingObj = Any
class CohereTextConfig(BaseConfig):
@ -96,11 +92,11 @@ class CohereTextConfig(BaseConfig):
def validate_environment(
self,
api_key: str,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
return cohere_validate_environment(api_key=api_key, headers=headers)
@ -111,7 +107,7 @@ class CohereTextConfig(BaseConfig):
raise NotImplementedError
def get_error_class(
self, error_message: str, status_code: int, headers: dict
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(status_code=status_code, message=error_message)
@ -172,12 +168,12 @@ class CohereTextConfig(BaseConfig):
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingObj,
api_key: str,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
encoding: Any,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
raise NotImplementedError

View file

@ -0,0 +1,355 @@
import copy
import json
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)
import httpx # type: ignore
import requests # type: ignore
from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
import litellm
import litellm.litellm_core_utils
import litellm.types
import litellm.types.utils
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseLLMHTTPHandler:
async def async_completion(
self,
custom_llm_provider: str,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: dict,
timeout: Union[float, httpx.Timeout],
model: str,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
messages: list,
optional_params: dict,
encoding: str,
api_key: Optional[str] = None,
):
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider)
)
try:
response = await async_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
return provider_config.transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
encoding=encoding,
)
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_llm_provider: str,
model_response: ModelResponse,
encoding,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
acompletion: bool,
stream: Optional[bool] = False,
api_key: Optional[str] = None,
headers={},
):
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
)
data = provider_config.transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
if acompletion is True:
if stream is True:
data["stream"] = stream
return self.acompletion_stream_function(
model=model,
messages=messages,
api_base=api_base,
headers=headers,
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
timeout=timeout,
logging_obj=logging_obj,
data=data,
)
else:
return self.async_completion(
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
model=model,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
messages=messages,
optional_params=optional_params,
encoding=encoding,
)
if stream is True:
data["stream"] = stream
completion_stream, headers = self.make_sync_call(
provider_config=provider_config,
api_base=api_base,
headers=headers, # type: ignore
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
sync_httpx_client = _get_httpx_client()
try:
response = sync_httpx_client.post(
api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
)
except Exception as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
return provider_config.transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
encoding=encoding,
)
def make_sync_call(
self,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
) -> Tuple[Any, httpx.Headers]:
sync_httpx_client = _get_httpx_client()
try:
response = sync_httpx_client.post(
api_base, headers=headers, data=data, stream=True, timeout=timeout
)
except httpx.HTTPStatusError as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
except Exception as e:
for exception in litellm.LITELLM_EXCEPTION_TYPES:
if isinstance(e, exception):
raise e
raise self._handle_error(
e=e,
provider_config=provider_config,
)
if response.status_code != 200:
raise BaseLLMException(
status_code=response.status_code,
message=str(response.read()),
)
completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.iter_lines(), sync_stream=True
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream, response.headers
async def acompletion_stream_function(
self,
model: str,
messages: list,
api_base: str,
custom_llm_provider: str,
headers: dict,
provider_config: BaseConfig,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
data: dict,
):
data["stream"] = True
completion_stream, _response_headers = await self.make_async_call(
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=json.dumps(data),
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
return streamwrapper
async def make_async_call(
self,
custom_llm_provider: str,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: str,
messages: list,
logging_obj: LiteLLMLoggingObj,
timeout: Optional[Union[float, httpx.Timeout]],
) -> Tuple[Any, httpx.Headers]:
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider)
)
try:
response = await async_httpx_client.post(
api_base, headers=headers, data=data, stream=True, timeout=timeout
)
except httpx.HTTPStatusError as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
except Exception as e:
for exception in litellm.LITELLM_EXCEPTION_TYPES:
if isinstance(e, exception):
raise e
raise self._handle_error(
e=e,
provider_config=provider_config,
)
if response.status_code != 200:
raise BaseLLMException(
status_code=response.status_code,
message=str(response.read()),
)
completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.aiter_lines(), sync_stream=False
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream, response.headers
def _handle_error(self, e: Exception, provider_config: BaseConfig):
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
if error_response and hasattr(error_response, "text"):
error_text = getattr(error_response, "text", error_text)
raise provider_config.error_class( # type: ignore
message=error_text,
status_code=status_code,
headers=error_headers,
)
def embedding(self):
pass
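
The async streaming path above is what litellm.acompletion(..., stream=True) exercises for cohere_chat after this change. A minimal hedged usage, mirroring the async test added further down (assumes a Cohere API key is set in the environment):

import asyncio
import litellm


async def main():
    # Routed through BaseLLMHTTPHandler.acompletion_stream_function after this commit.
    response = await litellm.acompletion(
        model="cohere_chat/command-r",
        messages=[{"role": "user", "content": "Hey"}],
        max_tokens=10,
        stream=True,
    )
    async for chunk in response:
        print(chunk)


asyncio.run(main())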

View file

@ -111,9 +111,9 @@ from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.bedrock.image.image_handler import BedrockImageGeneration
from .llms.clarifai.chat import handler
from .llms.cohere.chat import handler as cohere_chat
from .llms.cohere.completion import completion as cohere_completion # type: ignore
from .llms.cohere.embed import handler as cohere_embed
from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks.chat.handler import DatabricksChatCompletion
from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
@ -233,6 +233,7 @@ sagemaker_llm = SagemakerLLM()
watsonx_chat_completion = WatsonXChatHandler()
openai_like_embedding = OpenAILikeEmbeddingHandler()
databricks_embedding = DatabricksEmbeddingHandler()
base_llm_http_handler = BaseLLMHTTPHandler()
####### COMPLETION ENDPOINTS ################
@ -446,6 +447,7 @@ async def acompletion(
or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq"
or custom_llm_provider == "nvidia_nim"
or custom_llm_provider == "cohere_chat"
or custom_llm_provider == "cerebras"
or custom_llm_provider == "sambanova"
or custom_llm_provider == "ai21_chat"
@ -1941,15 +1943,15 @@ def completion( # type: ignore # noqa: PLR0915
cohere_key = (
api_key
or litellm.cohere_key
or get_secret("COHERE_API_KEY")
or get_secret("CO_API_KEY")
or get_secret_str("COHERE_API_KEY")
or get_secret_str("CO_API_KEY")
or litellm.api_key
)
api_base = (
api_base
or litellm.api_base
or get_secret("COHERE_API_BASE")
or get_secret_str("COHERE_API_BASE")
or "https://api.cohere.ai/v1/chat"
)
@ -1960,32 +1962,22 @@ def completion( # type: ignore # noqa: PLR0915
if extra_headers is not None:
headers.update(extra_headers)
model_response = cohere_chat.completion(
response = base_llm_http_handler.completion(
model=model,
stream=stream,
messages=messages,
acompletion=acompletion,
api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
custom_llm_provider="cohere_chat",
timeout=timeout,
headers=headers,
logger_fn=logger_fn,
encoding=encoding,
api_key=cohere_key,
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit Cohere's requirements
)
# if "stream" in optional_params and optional_params["stream"] is True:
# # don't try to access stream object,
# response = CustomStreamWrapper(
# model_response,
# model,
# custom_llm_provider="cohere_chat",
# logging_obj=logging,
# _response_headers=headers,
# )
# return response
response = model_response
elif custom_llm_provider == "maritalk":
maritalk_key = (
api_key

View file

@ -57,3 +57,167 @@ async def test_chat_completion_cohere_citations(stream):
assert response.citations is not None
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_cohere_command_r_plus_function_call():
litellm.set_verbose = True
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
try:
# test without max tokens
response = completion(
model="command-r-plus",
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
messages.append(
response.choices[0].message.model_dump()
) # Add assistant tool invokes
tool_result = (
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
)
# Add user submitted tool results in the OpenAI format
messages.append(
{
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"role": "tool",
"name": response.choices[0].message.tool_calls[0].function.name,
"content": tool_result,
}
)
# In the second response, Cohere should deduce answer from tool results
second_response = completion(
model="command-r-plus",
messages=messages,
tools=tools,
tool_choice="auto",
force_single_step=True,
)
print(second_response)
except litellm.Timeout:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# @pytest.mark.skip(reason="flaky test, times out frequently")
@pytest.mark.flaky(retries=6, delay=1)
def test_completion_cohere():
try:
# litellm.set_verbose=True
messages = [
{"role": "system", "content": "You're a good bot"},
{"role": "assistant", "content": [{"text": "2", "type": "text"}]},
{"role": "assistant", "content": [{"text": "3", "type": "text"}]},
{
"role": "user",
"content": "Hey",
},
]
response = completion(
model="command-r",
messages=messages,
)
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# FYI - cohere_chat looks quite unstable, even when testing locally
@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_chat_completion_cohere(sync_mode):
try:
litellm.set_verbose = True
messages = [
{"role": "system", "content": "You're a good bot"},
{
"role": "user",
"content": "Hey",
},
]
if sync_mode is False:
response = await litellm.acompletion(
model="cohere_chat/command-r",
messages=messages,
max_tokens=10,
)
else:
response = completion(
model="cohere_chat/command-r",
messages=messages,
max_tokens=10,
)
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [False])
async def test_chat_completion_cohere_stream(sync_mode):
try:
litellm.set_verbose = True
messages = [
{"role": "system", "content": "You're a good bot"},
{
"role": "user",
"content": "Hey",
},
]
if sync_mode is False:
response = await litellm.acompletion(
model="cohere_chat/command-r",
messages=messages,
max_tokens=10,
stream=True,
)
print("async cohere stream response", response)
async for chunk in response:
print(chunk)
else:
response = completion(
model="cohere_chat/command-r",
messages=messages,
max_tokens=10,
stream=True,
)
print(response)
for chunk in response:
print(chunk)
except litellm.APIConnectionError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -411,7 +411,9 @@ def test_dynamic_drop_params(drop_params):
def test_dynamic_drop_params_e2e():
with patch("requests.post", new=MagicMock()) as mock_response:
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post", new=MagicMock()
) as mock_response:
try:
response = litellm.completion(
model="command-r",
@ -457,7 +459,9 @@ def test_dynamic_drop_params_parallel_tool_calls():
"""
https://github.com/BerriAI/litellm/issues/4584
"""
with patch("requests.post", new=MagicMock()) as mock_response:
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post", new=MagicMock()
) as mock_response:
try:
response = litellm.completion(
model="command-r",
@ -498,7 +502,9 @@ def test_dynamic_drop_additional_params(drop_params):
def test_dynamic_drop_additional_params_e2e():
with patch("requests.post", new=MagicMock()) as mock_response:
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post", new=MagicMock()
) as mock_response:
try:
response = litellm.completion(
model="command-r",

View file

@ -695,79 +695,6 @@ async def test_anthropic_no_content_error():
pytest.fail(f"An unexpected error occurred - {str(e)}")
def test_completion_cohere_command_r_plus_function_call():
litellm.set_verbose = True
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [
{
"role": "user",
"content": "What's the weather like in Boston today in Fahrenheit?",
}
]
try:
# test without max tokens
response = completion(
model="command-r-plus",
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
messages.append(
response.choices[0].message.model_dump()
) # Add assistant tool invokes
tool_result = (
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
)
# Add user submitted tool results in the OpenAI format
messages.append(
{
"tool_call_id": response.choices[0].message.tool_calls[0].id,
"role": "tool",
"name": response.choices[0].message.tool_calls[0].function.name,
"content": tool_result,
}
)
# In the second response, Cohere should deduce answer from tool results
second_response = completion(
model="command-r-plus",
messages=messages,
tools=tools,
tool_choice="auto",
force_single_step=True,
)
print(second_response)
except litellm.Timeout:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_parse_xml_params():
from litellm.llms.prompt_templates.factory import parse_xml_params
@ -2120,27 +2047,6 @@ def test_ollama_image():
# hf_test_error_logs()
# def test_completion_cohere(): # commenting out,for now as the cohere endpoint is being flaky
# try:
# litellm.CohereConfig(max_tokens=10, stop_sequences=["a"])
# response = completion(
# model="command-nightly", messages=messages, logger_fn=logger_fn
# )
# # Add any assertions here to check the response
# print(response)
# response_str = response["choices"][0]["message"]["content"]
# response_str_2 = response.choices[0].message.content
# if type(response_str) != str:
# pytest.fail(f"Error occurred: {e}")
# if type(response_str_2) != str:
# pytest.fail(f"Error occurred: {e}")
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_completion_cohere()
def test_completion_openai():
try:
litellm.set_verbose = True
@ -3550,9 +3456,6 @@ def test_completion_bedrock_titan_null_response():
# test_completion_bedrock_claude()
# test_completion_bedrock_cohere()
# def test_completion_bedrock_claude_stream():
# print("calling claude")
# litellm.set_verbose = False
@ -3722,78 +3625,6 @@ def test_completion_anyscale_api():
# test_completion_anyscale_api()
# @pytest.mark.skip(reason="flaky test, times out frequently")
@pytest.mark.flaky(retries=6, delay=1)
def test_completion_cohere():
try:
# litellm.set_verbose=True
messages = [
{"role": "system", "content": "You're a good bot"},
{"role": "assistant", "content": [{"text": "2", "type": "text"}]},
{"role": "assistant", "content": [{"text": "3", "type": "text"}]},
{
"role": "user",
"content": "Hey",
},
]
response = completion(
model="command-r",
messages=messages,
extra_headers={"Helicone-Property-Locale": "ko"},
)
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# FYI - cohere_chat looks quite unstable, even when testing locally
def test_chat_completion_cohere():
try:
litellm.set_verbose = True
messages = [
{"role": "system", "content": "You're a good bot"},
{
"role": "user",
"content": "Hey",
},
]
response = completion(
model="cohere_chat/command-r",
messages=messages,
max_tokens=10,
)
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_chat_completion_cohere_stream():
try:
litellm.set_verbose = False
messages = [
{"role": "system", "content": "You're a good bot"},
{
"role": "user",
"content": "Hey",
},
]
response = completion(
model="cohere_chat/command-r",
messages=messages,
max_tokens=10,
stream=True,
)
print(response)
for chunk in response:
print(chunk)
except litellm.APIConnectionError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_azure_cloudflare_api():
litellm.set_verbose = True
try: