Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
(Refactor) Code Quality improvement - use Common base handler for Cohere (#7117)
* fix use new format for Cohere config
* fix base llm http handler
* Litellm code qa common config (#7116)
  * feat(base_llm): initial commit for common base config class. Addresses code qa critique https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132
  * feat(base_llm/): add transform request/response abstract methods to base config class
  Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
* use base transform helpers
* use base_llm_http_handler for cohere
* working cohere using base llm handler
* add async cohere chat completion support on base handler
* fix completion code
* working sync cohere stream
* add async support cohere_chat
* fix types get_model_response_iterator
* async / sync tests cohere
* feat cohere using base llm class
* fix linting errors
* fix _abc error
* add cohere params to transformation
* remove old cohere file
* fix type error
* fix merge conflicts
* fix cohere merge conflicts
* fix linting error
* fix litellm.llms.custom_httpx.http_handler.HTTPHandler.post
* fix passing cohere specific params

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
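The refactor replaces the provider-specific Cohere handler with a shared BaseConfig / BaseLLMHTTPHandler pair: the provider config owns request/response transformation and environment validation, while one HTTP handler owns transport, streaming, and error handling. Below is a minimal sketch of that pattern, not the actual litellm code: the class and method names (BaseConfig, BaseLLMHTTPHandler, validate_environment, transform_request, transform_response) come from the diff, but the signatures are simplified and the flat completion flow is an assumption for illustration.

```python
# Sketch of the "common base handler" flow this PR introduces (simplified, assumed
# signatures). One HTTP handler drives any provider through its BaseConfig hooks.
from abc import ABC, abstractmethod
from typing import Optional

import httpx


class BaseConfig(ABC):
    """Provider-specific hooks; CohereChatConfig in the diff implements these."""

    @abstractmethod
    def validate_environment(self, headers: dict, model: str, messages: list,
                             optional_params: dict, api_key: Optional[str] = None) -> dict:
        """Return the request headers (auth, content-type) for this provider."""

    @abstractmethod
    def transform_request(self, model: str, messages: list, optional_params: dict,
                          litellm_params: dict, headers: dict) -> dict:
        """Map OpenAI-style inputs to the provider's request body."""

    @abstractmethod
    def transform_response(self, model: str, raw_response: httpx.Response) -> dict:
        """Map the provider's HTTP response back to a normalized response object."""


class BaseLLMHTTPHandler:
    """Provider-agnostic transport: validate -> transform request -> POST -> transform response."""

    def completion(self, config: BaseConfig, api_base: str, model: str, messages: list,
                   optional_params: dict, litellm_params: dict,
                   api_key: Optional[str] = None) -> dict:
        headers = config.validate_environment(
            headers={}, model=model, messages=messages,
            optional_params=optional_params, api_key=api_key,
        )
        data = config.transform_request(
            model=model, messages=messages, optional_params=optional_params,
            litellm_params=litellm_params, headers=headers,
        )
        raw = httpx.post(api_base, headers=headers, json=data)
        return config.transform_response(model=model, raw_response=raw)
```

With this shape, adding a provider means implementing only the transform hooks; streaming reuses the same handler through the get_model_response_iterator hook, which is what the Cohere changes below wire up.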
This commit is contained in:
parent: 5bbf906c83
commit: ff7c95694d

14 changed files with 933 additions and 720 deletions
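The tests added at the bottom of this diff exercise the new code path end-to-end. A representative call is sketched here for orientation; it assumes COHERE_API_KEY is exported, and command-r is simply the model the tests use.

```python
import litellm

# Routes through BaseLLMHTTPHandler + CohereChatConfig after this refactor.
response = litellm.completion(
    model="cohere_chat/command-r",
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=10,
)
print(response.choices[0].message.content)

# Streaming goes through the same handler via get_model_response_iterator.
stream = litellm.completion(
    model="cohere_chat/command-r",
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=10,
    stream=True,
)
for chunk in stream:
    print(chunk)
```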
@@ -1752,6 +1752,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "text-completion-openai"
or self.custom_llm_provider == "text-completion-codestral"
or self.custom_llm_provider == "azure_text"
or self.custom_llm_provider == "cohere_chat"
or self.custom_llm_provider == "anthropic"
or self.custom_llm_provider == "anthropic_text"
or self.custom_llm_provider == "huggingface"
@@ -15,11 +15,11 @@ from litellm.types.utils import ModelResponse
|
|||
from ..common_utils import OpenAIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LoggingClass = LiteLLMLoggingObj
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LoggingClass = Any
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class OpenAIGPTConfig(BaseConfig):
|
||||
|
@@ -189,12 +189,12 @@ class OpenAIGPTConfig(BaseConfig):
|
|||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LoggingClass,
|
||||
api_key: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
encoding: Any,
|
||||
encoding: str,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
|
@@ -216,10 +216,10 @@ class OpenAIGPTConfig(BaseConfig):
|
|||
|
||||
def validate_environment(
|
||||
self,
|
||||
api_key: str,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
raise NotImplementedError
|
||||
@@ -564,11 +564,11 @@ class AnthropicConfig(BaseConfig):
|
|||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LoggingClass,
|
||||
api_key: str,
|
||||
request_data: Dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
_hidden_params: Dict = {}
|
||||
|
@@ -721,7 +721,7 @@ class AnthropicConfig(BaseConfig):
|
|||
return messages
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Dict
|
||||
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return AnthropicError(
|
||||
status_code=status_code,
|
||||
|
@@ -731,11 +731,11 @@ class AnthropicConfig(BaseConfig):
|
|||
|
||||
def validate_environment(
|
||||
self,
|
||||
api_key: str,
|
||||
headers: Dict,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> Dict:
|
||||
if api_key is None:
|
||||
raise litellm.AuthenticationError(
|
||||
@@ -4,7 +4,16 @@ Common base config for all LLM providers
|
|||
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
|
@@ -12,11 +21,11 @@ from litellm.types.llms.openai import AllMessageValues
|
|||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LoggingClass = LiteLLMLoggingObj
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LoggingClass = Any
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseLLMException(Exception):
|
||||
|
@@ -78,11 +87,11 @@ class BaseConfig(ABC):
|
|||
@abstractmethod
|
||||
def validate_environment(
|
||||
self,
|
||||
api_key: str,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
|
@@ -109,21 +118,26 @@ class BaseConfig(ABC):
|
|||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LoggingClass,
|
||||
api_key: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
encoding: Any,
|
||||
encoding: str,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_error_class(
|
||||
self,
|
||||
error_message: str,
|
||||
status_code: int,
|
||||
headers: dict,
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
pass
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str]],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
) -> Any:
|
||||
pass
|
||||
@@ -1,5 +1,5 @@
|
|||
import types
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
@@ -104,11 +104,11 @@ class ClarifaiConfig(BaseConfig):
|
|||
|
||||
def validate_environment(
|
||||
self,
|
||||
api_key: str,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
|
@@ -125,7 +125,7 @@ class ClarifaiConfig(BaseConfig):
|
|||
raise NotImplementedError
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: dict
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return ClarifaiError(message=error_message, status_code=status_code)
|
||||
|
||||
|
@@ -135,11 +135,11 @@ class ClarifaiConfig(BaseConfig):
|
|||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LoggingClass,
|
||||
api_key: str,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
encoding: str,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> litellm.ModelResponse:
|
||||
logging_obj.post_call(
|
||||
@@ -1,453 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import types
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.llms.cohere import ToolResultObject
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
GenericStreamingChunk,
|
||||
)
|
||||
from litellm.utils import Choices, Message, ModelResponse, Usage
|
||||
|
||||
from ...prompt_templates.factory import cohere_message_pt, cohere_messages_pt_v2
|
||||
from .transformation import CohereChatConfig, CohereError
|
||||
|
||||
|
||||
def translate_openai_tool_to_cohere(openai_tool):
|
||||
# cohere tools look like this
|
||||
"""
|
||||
{
|
||||
"name": "query_daily_sales_report",
|
||||
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
|
||||
"parameter_definitions": {
|
||||
"day": {
|
||||
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
|
||||
"type": "str",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# OpenAI tools look like this
|
||||
"""
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""
|
||||
cohere_tool = {
|
||||
"name": openai_tool["function"]["name"],
|
||||
"description": openai_tool["function"]["description"],
|
||||
"parameter_definitions": {},
|
||||
}
|
||||
|
||||
for param_name, param_def in openai_tool["function"]["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
required_params = (
|
||||
openai_tool.get("function", {}).get("parameters", {}).get("required", [])
|
||||
)
|
||||
cohere_param_def = {
|
||||
"description": param_def.get("description", ""),
|
||||
"type": param_def.get("type", ""),
|
||||
"required": param_name in required_params,
|
||||
}
|
||||
cohere_tool["parameter_definitions"][param_name] = cohere_param_def
|
||||
|
||||
return cohere_tool
|
||||
|
||||
|
||||
def construct_cohere_tool(tools=None):
|
||||
if tools is None:
|
||||
tools = []
|
||||
cohere_tools = []
|
||||
for tool in tools:
|
||||
cohere_tool = translate_openai_tool_to_cohere(tool)
|
||||
cohere_tools.append(cohere_tool)
|
||||
return cohere_tools
|
||||
|
||||
|
||||
async def make_call(
|
||||
client: Optional[AsyncHTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
json_mode: bool,
|
||||
) -> Tuple[Any, httpx.Headers]:
|
||||
if client is None:
|
||||
client = litellm.module_level_aclient
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
api_base, headers=headers, data=data, stream=True, timeout=timeout
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
raise CohereError(
|
||||
status_code=e.response.status_code,
|
||||
message=await e.response.aread(),
|
||||
)
|
||||
except Exception as e:
|
||||
for exception in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
if isinstance(e, exception):
|
||||
raise e
|
||||
raise CohereError(status_code=500, message=str(e))
|
||||
|
||||
completion_stream = ModelResponseIterator(
|
||||
streaming_response=response.aiter_lines(),
|
||||
sync_stream=False,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=completion_stream, # Pass the completion stream for logging
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream, response.headers
|
||||
|
||||
|
||||
def make_sync_call(
|
||||
client: Optional[HTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
) -> Tuple[Any, httpx.Headers]:
|
||||
if client is None:
|
||||
client = litellm.module_level_client # re-use a module level client
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
api_base, headers=headers, data=data, stream=True, timeout=timeout
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise CohereError(
|
||||
status_code=e.response.status_code,
|
||||
message=e.response.read(),
|
||||
)
|
||||
except Exception as e:
|
||||
for exception in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
if isinstance(e, exception):
|
||||
raise e
|
||||
raise CohereError(status_code=500, message=str(e))
|
||||
|
||||
if response.status_code != 200:
|
||||
|
||||
raise CohereError(
|
||||
status_code=response.status_code,
|
||||
message=response.read(),
|
||||
)
|
||||
|
||||
completion_stream = ModelResponseIterator(
|
||||
streaming_response=response.iter_lines(), sync_stream=True
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response="first stream response received",
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream, response.headers
|
||||
|
||||
|
||||
def completion( # noqa: PLR0915
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
client=None,
|
||||
timeout=None,
|
||||
):
|
||||
headers = litellm.CohereChatConfig().validate_environment(
|
||||
api_key=api_key,
|
||||
headers=headers,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
completion_url = api_base
|
||||
model = model
|
||||
most_recent_message, chat_history = cohere_messages_pt_v2(
|
||||
messages=messages, model=model, llm_provider="cohere_chat"
|
||||
)
|
||||
|
||||
## Load Config
|
||||
config = litellm.CohereConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
|
||||
optional_params["tools"] = cohere_tools
|
||||
if isinstance(most_recent_message, dict):
|
||||
optional_params["tool_results"] = [most_recent_message]
|
||||
elif isinstance(most_recent_message, str):
|
||||
optional_params["message"] = most_recent_message
|
||||
|
||||
## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails
|
||||
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
|
||||
optional_params["force_single_step"] = True
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"chat_history": chat_history,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=most_recent_message,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"headers": headers,
|
||||
"api_base": completion_url,
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=optional_params["stream"] if "stream" in optional_params else False,
|
||||
)
|
||||
## error handling for cohere calls
|
||||
if response.status_code != 200:
|
||||
raise CohereError(message=response.text, status_code=response.status_code)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
completion_stream, cohere_headers = make_sync_call(
|
||||
client=client,
|
||||
api_base=api_base,
|
||||
headers=headers, # type: ignore
|
||||
data=json.dumps(data),
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
timeout=timeout,
|
||||
)
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="cohere_chat",
|
||||
logging_obj=logging_obj,
|
||||
_response_headers=dict(cohere_headers),
|
||||
)
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=most_recent_message,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
completion_response = response.json()
|
||||
try:
|
||||
model_response.choices[0].message.content = completion_response["text"] # type: ignore
|
||||
except Exception:
|
||||
raise CohereError(message=response.text, status_code=response.status_code)
|
||||
|
||||
## ADD CITATIONS
|
||||
if "citations" in completion_response:
|
||||
setattr(model_response, "citations", completion_response["citations"])
|
||||
|
||||
## Tool calling response
|
||||
cohere_tools_response = completion_response.get("tool_calls", None)
|
||||
if cohere_tools_response is not None and cohere_tools_response != []:
|
||||
# convert cohere_tools_response to OpenAI response format
|
||||
tool_calls = []
|
||||
for tool in cohere_tools_response:
|
||||
function_name = tool.get("name", "")
|
||||
generation_id = tool.get("generation_id", "")
|
||||
parameters = tool.get("parameters", {})
|
||||
tool_call = {
|
||||
"id": f"call_{generation_id}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(parameters),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
|
||||
## CALCULATING USAGE - use cohere `billed_units` for returning usage
|
||||
billed_units = completion_response.get("meta", {}).get("billed_units", {})
|
||||
|
||||
prompt_tokens = billed_units.get("input_tokens", 0)
|
||||
completion_tokens = billed_units.get("output_tokens", 0)
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(
|
||||
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
|
||||
):
|
||||
self.streaming_response = streaming_response
|
||||
self.response_iterator = self.streaming_response
|
||||
self.content_blocks: List = []
|
||||
self.tool_index = -1
|
||||
self.json_mode = json_mode
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
try:
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
provider_specific_fields = None
|
||||
|
||||
index = int(chunk.get("index", 0))
|
||||
|
||||
if "text" in chunk:
|
||||
text = chunk["text"]
|
||||
elif "is_finished" in chunk and chunk["is_finished"] is True:
|
||||
is_finished = chunk["is_finished"]
|
||||
finish_reason = chunk["finish_reason"]
|
||||
|
||||
if "citations" in chunk:
|
||||
provider_specific_fields = {"citations": chunk["citations"]}
|
||||
|
||||
returned_chunk = GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
index=index,
|
||||
provider_specific_fields=provider_specific_fields,
|
||||
)
|
||||
|
||||
return returned_chunk
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
chunk = self.response_iterator.__next__()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
data_json = json.loads(str_line)
|
||||
return self.chunk_parser(chunk=data_json)
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
chunk = await self.async_response_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
|
||||
data_json = json.loads(str_line)
|
||||
return self.chunk_parser(chunk=data_json)
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
@@ -1,21 +1,45 @@
|
|||
import types
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.prompt_templates.factory import cohere_messages_pt_v2
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ..common_utils import CohereError
|
||||
from ..common_utils import validate_environment as cohere_validate_environment
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
GenericStreamingChunk,
|
||||
ModelResponse,
|
||||
Usage,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LoggingObj = LiteLLMLoggingObj
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LoggingObj = Any
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class CohereError(BaseLLMException):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str,
|
||||
headers: Optional[httpx.Headers] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
class CohereChatConfig(BaseConfig):
|
||||
|
@@ -88,19 +112,36 @@ class CohereChatConfig(BaseConfig):
|
|||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return super().get_config()
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Return headers to use for cohere chat completion request
|
||||
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[AllMessageValues]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: dict
|
||||
) -> BaseLLMException:
|
||||
return CohereError(status_code=status_code, message=error_message)
|
||||
Cohere API Ref: https://docs.cohere.com/reference/chat
|
||||
Expected headers:
|
||||
{
|
||||
"Request-Source": "unspecified:litellm",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"Authorization": "bearer $CO_API_KEY"
|
||||
}
|
||||
"""
|
||||
headers.update(
|
||||
{
|
||||
"Request-Source": "unspecified:litellm",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"bearer {api_key}"
|
||||
return headers
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return [
|
||||
|
@@ -156,29 +197,293 @@ class CohereChatConfig(BaseConfig):
|
|||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
## Load Config
|
||||
for k, v in litellm.CohereChatConfig.get_config().items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
most_recent_message, chat_history = cohere_messages_pt_v2(
|
||||
messages=messages, model=model, llm_provider="cohere_chat"
|
||||
)
|
||||
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
|
||||
optional_params["tools"] = cohere_tools
|
||||
if isinstance(most_recent_message, dict):
|
||||
optional_params["tool_results"] = [most_recent_message]
|
||||
elif isinstance(most_recent_message, str):
|
||||
optional_params["message"] = most_recent_message
|
||||
|
||||
## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails
|
||||
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
|
||||
optional_params["force_single_step"] = True
|
||||
|
||||
return optional_params
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LoggingObj,
|
||||
api_key: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
encoding: Any,
|
||||
encoding: str,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
model_response.choices[0].message.content = raw_response_json["text"] # type: ignore
|
||||
except Exception:
|
||||
raise CohereError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
## ADD CITATIONS
|
||||
if "citations" in raw_response_json:
|
||||
setattr(model_response, "citations", raw_response_json["citations"])
|
||||
|
||||
## Tool calling response
|
||||
cohere_tools_response = raw_response_json.get("tool_calls", None)
|
||||
if cohere_tools_response is not None and cohere_tools_response != []:
|
||||
# convert cohere_tools_response to OpenAI response format
|
||||
tool_calls = []
|
||||
for tool in cohere_tools_response:
|
||||
function_name = tool.get("name", "")
|
||||
generation_id = tool.get("generation_id", "")
|
||||
parameters = tool.get("parameters", {})
|
||||
tool_call = {
|
||||
"id": f"call_{generation_id}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(parameters),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
|
||||
## CALCULATING USAGE - use cohere `billed_units` for returning usage
|
||||
billed_units = raw_response_json.get("meta", {}).get("billed_units", {})
|
||||
|
||||
prompt_tokens = billed_units.get("input_tokens", 0)
|
||||
completion_tokens = billed_units.get("output_tokens", 0)
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
def _construct_cohere_tool(
|
||||
self,
|
||||
tools: Optional[list] = None,
|
||||
):
|
||||
if tools is None:
|
||||
tools = []
|
||||
cohere_tools = []
|
||||
for tool in tools:
|
||||
cohere_tool = self._translate_openai_tool_to_cohere(tool)
|
||||
cohere_tools.append(cohere_tool)
|
||||
return cohere_tools
|
||||
|
||||
def _translate_openai_tool_to_cohere(
|
||||
self,
|
||||
openai_tool: dict,
|
||||
):
|
||||
# cohere tools look like this
|
||||
"""
|
||||
{
|
||||
"name": "query_daily_sales_report",
|
||||
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
|
||||
"parameter_definitions": {
|
||||
"day": {
|
||||
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
|
||||
"type": "str",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# OpenAI tools look like this
|
||||
"""
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""
|
||||
cohere_tool = {
|
||||
"name": openai_tool["function"]["name"],
|
||||
"description": openai_tool["function"]["description"],
|
||||
"parameter_definitions": {},
|
||||
}
|
||||
|
||||
for param_name, param_def in openai_tool["function"]["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
required_params = (
|
||||
openai_tool.get("function", {})
|
||||
.get("parameters", {})
|
||||
.get("required", [])
|
||||
)
|
||||
cohere_param_def = {
|
||||
"description": param_def.get("description", ""),
|
||||
"type": param_def.get("type", ""),
|
||||
"required": param_name in required_params,
|
||||
}
|
||||
cohere_tool["parameter_definitions"][param_name] = cohere_param_def
|
||||
|
||||
return cohere_tool
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str]],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
):
|
||||
return ModelResponseIterator(
|
||||
streaming_response=streaming_response,
|
||||
sync_stream=sync_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return CohereError(status_code=status_code, message=error_message)
|
||||
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[AllMessageValues]:
|
||||
raise NotImplementedError
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
api_key: str,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
) -> dict:
|
||||
return cohere_validate_environment(api_key=api_key, headers=headers)
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(
|
||||
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
|
||||
):
|
||||
self.streaming_response = streaming_response
|
||||
self.response_iterator = self.streaming_response
|
||||
self.content_blocks: List = []
|
||||
self.tool_index = -1
|
||||
self.json_mode = json_mode
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
try:
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
provider_specific_fields = None
|
||||
|
||||
index = int(chunk.get("index", 0))
|
||||
|
||||
if "text" in chunk:
|
||||
text = chunk["text"]
|
||||
elif "is_finished" in chunk and chunk["is_finished"] is True:
|
||||
is_finished = chunk["is_finished"]
|
||||
finish_reason = chunk["finish_reason"]
|
||||
|
||||
if "citations" in chunk:
|
||||
provider_specific_fields = {"citations": chunk["citations"]}
|
||||
|
||||
returned_chunk = GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
index=index,
|
||||
provider_specific_fields=provider_specific_fields,
|
||||
)
|
||||
|
||||
return returned_chunk
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
chunk = self.response_iterator.__next__()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
data_json = json.loads(str_line)
|
||||
return self.chunk_parser(chunk=data_json)
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
chunk = await self.async_response_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
|
||||
data_json = json.loads(str_line)
|
||||
return self.chunk_parser(chunk=data_json)
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
@@ -1,3 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
from litellm.llms.base_llm.transformation import BaseLLMException
|
||||
|
||||
|
||||
|
@@ -6,7 +8,7 @@ class CohereError(BaseLLMException):
|
|||
super().__init__(status_code=status_code, message=message)
|
||||
|
||||
|
||||
def validate_environment(*, api_key: str, headers: dict) -> dict:
|
||||
def validate_environment(*, api_key: Optional[str], headers: dict) -> dict:
|
||||
headers.update(
|
||||
{
|
||||
"Request-Source": "unspecified:litellm",
|
||||
@@ -1,13 +1,9 @@
|
|||
import types
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.transformation import (
|
||||
BaseConfig,
|
||||
BaseLLMException,
|
||||
LoggingClass,
|
||||
)
|
||||
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
@@ -15,11 +11,11 @@ from ..common_utils import CohereError
|
|||
from ..common_utils import validate_environment as cohere_validate_environment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LoggingObj = LiteLLMLoggingObj
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LoggingObj = Any
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class CohereTextConfig(BaseConfig):
|
||||
|
@@ -96,11 +92,11 @@ class CohereTextConfig(BaseConfig):
|
|||
|
||||
def validate_environment(
|
||||
self,
|
||||
api_key: str,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
return cohere_validate_environment(api_key=api_key, headers=headers)
|
||||
|
||||
|
@@ -111,7 +107,7 @@ class CohereTextConfig(BaseConfig):
|
|||
raise NotImplementedError
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: dict
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return CohereError(status_code=status_code, message=error_message)
|
||||
|
||||
|
@@ -172,12 +168,12 @@ class CohereTextConfig(BaseConfig):
|
|||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LoggingObj,
|
||||
api_key: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
encoding: Any,
|
||||
encoding: str,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
raise NotImplementedError
|
||||
litellm/llms/custom_httpx/llm_http_handler.py (new file, 355 lines)
@@ -0,0 +1,355 @@
|
|||
import copy
|
||||
import json
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
|
||||
|
||||
import litellm
|
||||
import litellm.litellm_core_utils
|
||||
import litellm.types
|
||||
import litellm.types.utils
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseLLMHTTPHandler:
|
||||
async def async_completion(
|
||||
self,
|
||||
custom_llm_provider: str,
|
||||
provider_config: BaseConfig,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: dict,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
model: str,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
messages: list,
|
||||
optional_params: dict,
|
||||
encoding: str,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders(custom_llm_provider)
|
||||
)
|
||||
try:
|
||||
response = await async_httpx_client.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._handle_error(e=e, provider_config=provider_config)
|
||||
return provider_config.transform_response(
|
||||
model=model,
|
||||
raw_response=response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
request_data=data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_llm_provider: str,
|
||||
model_response: ModelResponse,
|
||||
encoding,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
litellm_params: dict,
|
||||
acompletion: bool,
|
||||
stream: Optional[bool] = False,
|
||||
api_key: Optional[str] = None,
|
||||
headers={},
|
||||
):
|
||||
provider_config = ProviderConfigManager.get_provider_chat_config(
|
||||
model=model, provider=litellm.LlmProviders(custom_llm_provider)
|
||||
)
|
||||
# get config from model, custom llm provider
|
||||
headers = provider_config.validate_environment(
|
||||
api_key=api_key,
|
||||
headers=headers,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
data = provider_config.transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
if acompletion is True:
|
||||
if stream is True:
|
||||
data["stream"] = stream
|
||||
return self.acompletion_stream_function(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
provider_config=provider_config,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
data=data,
|
||||
)
|
||||
|
||||
else:
|
||||
return self.async_completion(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
provider_config=provider_config,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=data,
|
||||
timeout=timeout,
|
||||
model=model,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
if stream is True:
|
||||
data["stream"] = stream
|
||||
completion_stream, headers = self.make_sync_call(
|
||||
provider_config=provider_config,
|
||||
api_base=api_base,
|
||||
headers=headers, # type: ignore
|
||||
data=json.dumps(data),
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
timeout=timeout,
|
||||
)
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
sync_httpx_client = _get_httpx_client()
|
||||
|
||||
try:
|
||||
response = sync_httpx_client.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
|
||||
return provider_config.transform_response(
|
||||
model=model,
|
||||
raw_response=response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
request_data=data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def make_sync_call(
|
||||
self,
|
||||
provider_config: BaseConfig,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
) -> Tuple[Any, httpx.Headers]:
|
||||
sync_httpx_client = _get_httpx_client()
|
||||
try:
|
||||
response = sync_httpx_client.post(
|
||||
api_base, headers=headers, data=data, stream=True, timeout=timeout
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
except Exception as e:
|
||||
for exception in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
if isinstance(e, exception):
|
||||
raise e
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BaseLLMException(
|
||||
status_code=response.status_code,
|
||||
message=str(response.read()),
|
||||
)
|
||||
completion_stream = provider_config.get_model_response_iterator(
|
||||
streaming_response=response.iter_lines(), sync_stream=True
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response="first stream response received",
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream, response.headers
|
||||
|
||||
async def acompletion_stream_function(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_llm_provider: str,
|
||||
headers: dict,
|
||||
provider_config: BaseConfig,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
data: dict,
|
||||
):
|
||||
data["stream"] = True
|
||||
completion_stream, _response_headers = await self.make_async_call(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
provider_config=provider_config,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
timeout=timeout,
|
||||
)
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
async def make_async_call(
|
||||
self,
|
||||
custom_llm_provider: str,
|
||||
provider_config: BaseConfig,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
messages: list,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
) -> Tuple[Any, httpx.Headers]:
|
||||
async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders(custom_llm_provider)
|
||||
)
|
||||
try:
|
||||
response = await async_httpx_client.post(
|
||||
api_base, headers=headers, data=data, stream=True, timeout=timeout
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
except Exception as e:
|
||||
for exception in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
if isinstance(e, exception):
|
||||
raise e
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BaseLLMException(
|
||||
status_code=response.status_code,
|
||||
message=str(response.read()),
|
||||
)
|
||||
|
||||
completion_stream = provider_config.get_model_response_iterator(
|
||||
streaming_response=response.aiter_lines(), sync_stream=False
|
||||
)
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response="first stream response received",
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream, response.headers
|
||||
|
||||
def _handle_error(self, e: Exception, provider_config: BaseConfig):
|
||||
status_code = getattr(e, "status_code", 500)
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_text = getattr(e, "text", str(e))
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
if error_response and hasattr(error_response, "text"):
|
||||
error_text = getattr(error_response, "text", error_text)
|
||||
raise provider_config.error_class( # type: ignore
|
||||
message=error_text,
|
||||
status_code=status_code,
|
||||
headers=error_headers,
|
||||
)
|
||||
|
||||
def embedding(self):
|
||||
pass
@@ -111,9 +111,9 @@ from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
|||
from .llms.bedrock.embed.embedding import BedrockEmbedding
|
||||
from .llms.bedrock.image.image_handler import BedrockImageGeneration
|
||||
from .llms.clarifai.chat import handler
|
||||
from .llms.cohere.chat import handler as cohere_chat
|
||||
from .llms.cohere.completion import completion as cohere_completion # type: ignore
|
||||
from .llms.cohere.embed import handler as cohere_embed
|
||||
from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
|
||||
from .llms.databricks.chat.handler import DatabricksChatCompletion
|
||||
from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
|
||||
|
@@ -233,6 +233,7 @@ sagemaker_llm = SagemakerLLM()
|
|||
watsonx_chat_completion = WatsonXChatHandler()
|
||||
openai_like_embedding = OpenAILikeEmbeddingHandler()
|
||||
databricks_embedding = DatabricksEmbeddingHandler()
|
||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@@ -446,6 +447,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "groq"
|
||||
or custom_llm_provider == "nvidia_nim"
|
||||
or custom_llm_provider == "cohere_chat"
|
||||
or custom_llm_provider == "cerebras"
|
||||
or custom_llm_provider == "sambanova"
|
||||
or custom_llm_provider == "ai21_chat"
|
||||
|
@@ -1941,15 +1943,15 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
cohere_key = (
|
||||
api_key
|
||||
or litellm.cohere_key
|
||||
or get_secret("COHERE_API_KEY")
|
||||
or get_secret("CO_API_KEY")
|
||||
or get_secret_str("COHERE_API_KEY")
|
||||
or get_secret_str("CO_API_KEY")
|
||||
or litellm.api_key
|
||||
)
|
||||
|
||||
api_base = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret("COHERE_API_BASE")
|
||||
or get_secret_str("COHERE_API_BASE")
|
||||
or "https://api.cohere.ai/v1/chat"
|
||||
)
|
||||
|
||||
|
@@ -1960,32 +1962,22 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
if extra_headers is not None:
|
||||
headers.update(extra_headers)
|
||||
|
||||
model_response = cohere_chat.completion(
|
||||
response = base_llm_http_handler.completion(
|
||||
model=model,
|
||||
stream=stream,
|
||||
messages=messages,
|
||||
acompletion=acompletion,
|
||||
api_base=api_base,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
custom_llm_provider="cohere_chat",
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
api_key=cohere_key,
|
||||
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
||||
)
|
||||
|
||||
# if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# # don't try to access stream object,
|
||||
# response = CustomStreamWrapper(
|
||||
# model_response,
|
||||
# model,
|
||||
# custom_llm_provider="cohere_chat",
|
||||
# logging_obj=logging,
|
||||
# _response_headers=headers,
|
||||
# )
|
||||
# return response
|
||||
response = model_response
|
||||
elif custom_llm_provider == "maritalk":
|
||||
maritalk_key = (
|
||||
api_key
|
||||
@@ -57,3 +57,167 @@ async def test_chat_completion_cohere_citations(stream):
|
|||
assert response.citations is not None
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_cohere_command_r_plus_function_call():
|
||||
litellm.set_verbose = True
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||
}
|
||||
]
|
||||
try:
|
||||
# test without max tokens
|
||||
response = completion(
|
||||
model="command-r-plus",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
# Add any assertions, here to check response args
|
||||
print(response)
|
||||
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
|
||||
assert isinstance(
|
||||
response.choices[0].message.tool_calls[0].function.arguments, str
|
||||
)
|
||||
|
||||
messages.append(
|
||||
response.choices[0].message.model_dump()
|
||||
) # Add assistant tool invokes
|
||||
tool_result = (
|
||||
'{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
|
||||
)
|
||||
# Add user submitted tool results in the OpenAI format
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": response.choices[0].message.tool_calls[0].id,
|
||||
"role": "tool",
|
||||
"name": response.choices[0].message.tool_calls[0].function.name,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
# In the second response, Cohere should deduce answer from tool results
|
||||
second_response = completion(
|
||||
model="command-r-plus",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
force_single_step=True,
|
||||
)
|
||||
print(second_response)
|
||||
except litellm.Timeout:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# @pytest.mark.skip(reason="flaky test, times out frequently")
|
||||
@pytest.mark.flaky(retries=6, delay=1)
|
||||
def test_completion_cohere():
|
||||
try:
|
||||
# litellm.set_verbose=True
|
||||
messages = [
|
||||
{"role": "system", "content": "You're a good bot"},
|
||||
{"role": "assistant", "content": [{"text": "2", "type": "text"}]},
|
||||
{"role": "assistant", "content": [{"text": "3", "type": "text"}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hey",
|
||||
},
|
||||
]
|
||||
response = completion(
|
||||
model="command-r",
|
||||
messages=messages,
|
||||
)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# FYI - cohere_chat looks quite unstable, even when testing locally
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
async def test_chat_completion_cohere(sync_mode):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
messages = [
|
||||
{"role": "system", "content": "You're a good bot"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hey",
|
||||
},
|
||||
]
|
||||
if sync_mode is False:
|
||||
response = await litellm.acompletion(
|
||||
model="cohere_chat/command-r",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
)
|
||||
else:
|
||||
response = completion(
|
||||
model="cohere_chat/command-r",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("sync_mode", [False])
|
||||
async def test_chat_completion_cohere_stream(sync_mode):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
messages = [
|
||||
{"role": "system", "content": "You're a good bot"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hey",
|
||||
},
|
||||
]
|
||||
if sync_mode is False:
|
||||
response = await litellm.acompletion(
|
||||
model="cohere_chat/command-r",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
stream=True,
|
||||
)
|
||||
print("async cohere stream response", response)
|
||||
async for chunk in response:
|
||||
print(chunk)
|
||||
else:
|
||||
response = completion(
|
||||
model="cohere_chat/command-r",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
stream=True,
|
||||
)
|
||||
print(response)
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
except litellm.APIConnectionError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
|
|
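For context on the citations assertion at the top of this hunk, here is a minimal sketch of a citation-style request. It is not part of the diff; it assumes COHERE_API_KEY is set and that the Cohere-specific `documents` parameter is forwarded by the new transformation (the commit notes fixing the passing of Cohere-specific params).

# Hypothetical sketch; `documents` pass-through is an assumption, not shown in the diff.
import litellm

def cohere_citations_example():
    response = litellm.completion(
        model="cohere_chat/command-r",
        messages=[{"role": "user", "content": "Which penguins are the tallest?"}],
        documents=[
            {"title": "Tall penguins", "text": "Emperor penguins are the tallest."},
            {"title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica."},
        ],
    )
    # test_chat_completion_cohere_citations asserts that this attribute is populated.
    print(response.citations)

if __name__ == "__main__":
    cohere_citations_example()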
@@ -411,7 +411,9 @@ def test_dynamic_drop_params(drop_params):


def test_dynamic_drop_params_e2e():
    with patch("requests.post", new=MagicMock()) as mock_response:
    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post", new=MagicMock()
    ) as mock_response:
        try:
            response = litellm.completion(
                model="command-r",
@@ -457,7 +459,9 @@ def test_dynamic_drop_params_parallel_tool_calls():
    """
    https://github.com/BerriAI/litellm/issues/4584
    """
    with patch("requests.post", new=MagicMock()) as mock_response:
    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post", new=MagicMock()
    ) as mock_response:
        try:
            response = litellm.completion(
                model="command-r",
@@ -498,7 +502,9 @@ def test_dynamic_drop_additional_params(drop_params):


def test_dynamic_drop_additional_params_e2e():
    with patch("requests.post", new=MagicMock()) as mock_response:
    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post", new=MagicMock()
    ) as mock_response:
        try:
            response = litellm.completion(
                model="command-r",
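The three hunks above change the mock target from requests.post to litellm.llms.custom_httpx.http_handler.HTTPHandler.post, since command-r requests now go through the shared HTTP handler rather than requests. A standalone sketch of the same mocking pattern follows; it is an illustration, not code from the repo, and assumes the mocked post() returns no usable response, so the completion call is expected to fail and only the captured request is inspected.

from unittest.mock import MagicMock, patch

import litellm

def check_command_r_uses_http_handler():
    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post", new=MagicMock()
    ) as mock_post:
        try:
            litellm.completion(
                model="command-r",
                messages=[{"role": "user", "content": "Hey"}],
            )
        except Exception:
            # Response parsing fails because the mock is not a real httpx.Response.
            pass
        # The request must have been routed through the shared handler.
        mock_post.assert_called_once()
        print(mock_post.call_args.kwargs.get("data"))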
@@ -695,79 +695,6 @@ async def test_anthropic_no_content_error():
        pytest.fail(f"An unexpected error occurred - {str(e)}")


def test_completion_cohere_command_r_plus_function_call():
    litellm.set_verbose = True
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    messages = [
        {
            "role": "user",
            "content": "What's the weather like in Boston today in Fahrenheit?",
        }
    ]
    try:
        # test without max tokens
        response = completion(
            model="command-r-plus",
            messages=messages,
            tools=tools,
            tool_choice="auto",
        )
        # Add any assertions, here to check response args
        print(response)
        assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
        assert isinstance(
            response.choices[0].message.tool_calls[0].function.arguments, str
        )

        messages.append(
            response.choices[0].message.model_dump()
        )  # Add assistant tool invokes
        tool_result = (
            '{"location": "Boston", "temperature": "72", "unit": "fahrenheit"}'
        )
        # Add user submitted tool results in the OpenAI format
        messages.append(
            {
                "tool_call_id": response.choices[0].message.tool_calls[0].id,
                "role": "tool",
                "name": response.choices[0].message.tool_calls[0].function.name,
                "content": tool_result,
            }
        )
        # In the second response, Cohere should deduce answer from tool results
        second_response = completion(
            model="command-r-plus",
            messages=messages,
            tools=tools,
            tool_choice="auto",
            force_single_step=True,
        )
        print(second_response)
    except litellm.Timeout:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_parse_xml_params():
    from litellm.llms.prompt_templates.factory import parse_xml_params
@@ -2120,27 +2047,6 @@ def test_ollama_image():
# hf_test_error_logs()


# def test_completion_cohere(): # commenting out,for now as the cohere endpoint is being flaky
#     try:
#         litellm.CohereConfig(max_tokens=10, stop_sequences=["a"])
#         response = completion(
#             model="command-nightly", messages=messages, logger_fn=logger_fn
#         )
#         # Add any assertions here to check the response
#         print(response)
#         response_str = response["choices"][0]["message"]["content"]
#         response_str_2 = response.choices[0].message.content
#         if type(response_str) != str:
#             pytest.fail(f"Error occurred: {e}")
#         if type(response_str_2) != str:
#             pytest.fail(f"Error occurred: {e}")
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")


# test_completion_cohere()


def test_completion_openai():
    try:
        litellm.set_verbose = True
@@ -3550,9 +3456,6 @@ def test_completion_bedrock_titan_null_response():
# test_completion_bedrock_claude()


# test_completion_bedrock_cohere()


# def test_completion_bedrock_claude_stream():
#     print("calling claude")
#     litellm.set_verbose = False
@@ -3722,78 +3625,6 @@ def test_completion_anyscale_api():


# test_completion_anyscale_api()


# @pytest.mark.skip(reason="flaky test, times out frequently")
@pytest.mark.flaky(retries=6, delay=1)
def test_completion_cohere():
    try:
        # litellm.set_verbose=True
        messages = [
            {"role": "system", "content": "You're a good bot"},
            {"role": "assistant", "content": [{"text": "2", "type": "text"}]},
            {"role": "assistant", "content": [{"text": "3", "type": "text"}]},
            {
                "role": "user",
                "content": "Hey",
            },
        ]
        response = completion(
            model="command-r",
            messages=messages,
            extra_headers={"Helicone-Property-Locale": "ko"},
        )
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# FYI - cohere_chat looks quite unstable, even when testing locally
def test_chat_completion_cohere():
    try:
        litellm.set_verbose = True
        messages = [
            {"role": "system", "content": "You're a good bot"},
            {
                "role": "user",
                "content": "Hey",
            },
        ]
        response = completion(
            model="cohere_chat/command-r",
            messages=messages,
            max_tokens=10,
        )
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_chat_completion_cohere_stream():
    try:
        litellm.set_verbose = False
        messages = [
            {"role": "system", "content": "You're a good bot"},
            {
                "role": "user",
                "content": "Hey",
            },
        ]
        response = completion(
            model="cohere_chat/command-r",
            messages=messages,
            max_tokens=10,
            stream=True,
        )
        print(response)
        for chunk in response:
            print(chunk)
    except litellm.APIConnectionError as e:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_azure_cloudflare_api():
    litellm.set_verbose = True
    try: