(Refactor) Code Quality improvement - Use Common base handler for Cohere /generate API (#7122)

* use validate_environment in common utils

* use transform request / response for cohere

* remove unused file

* use cohere base_llm_http_handler

* working cohere generate api on llm http handler

* streaming cohere generate api

* fix get_model_response_iterator

* fix streaming handler

* fix get_model_response_iterator

* test_cohere_generate_api_completion

* fix linting error

* fix testing cohere raising error

* fix get_model_response_iterator type

* add testing cohere generate api
Ishaan Jaff 2024-12-10 10:44:42 -08:00 committed by GitHub
parent 9c2316b7ec
commit 1b377d5229
9 changed files with 439 additions and 382 deletions
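The commit message bullets above track the refactor: the legacy requests-based handler for Cohere's /generate endpoint is deleted, and "cohere" requests are routed through the shared base_llm_http_handler, with request/response transformation in CohereTextConfig and streaming parsed by a common ModelResponseIterator. A minimal usage sketch of the refactored path, mirroring the tests added in this commit (assumes a valid COHERE_API_KEY in the environment):

import litellm

# Non-streaming call to the Cohere /generate provider path via the common HTTP handler
response = litellm.completion(
    model="cohere/command-nightly",
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=10,
)
print(response.choices[0].message.content)

# Streaming: chunks are parsed by the shared Cohere ModelResponseIterator
stream = litellm.completion(
    model="cohere/command-nightly",
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=10,
    stream=True,
)
for chunk in stream:
    print(chunk)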


@@ -411,32 +411,6 @@ class CustomStreamWrapper:
         except Exception:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")
 
-    def handle_cohere_chunk(self, chunk):
-        chunk = chunk.decode("utf-8")
-        data_json = json.loads(chunk)
-        try:
-            text = ""
-            is_finished = False
-            finish_reason = ""
-            index: Optional[int] = None
-            if "index" in data_json:
-                index = data_json.get("index")
-            if "text" in data_json:
-                text = data_json["text"]
-            elif "is_finished" in data_json:
-                is_finished = data_json["is_finished"]
-                finish_reason = data_json["finish_reason"]
-            else:
-                raise Exception(data_json)
-            return {
-                "index": index,
-                "text": text,
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
-        except Exception:
-            raise ValueError(f"Unable to parse response. Original response: {chunk}")
-
     def handle_azure_chunk(self, chunk):
         is_finished = False
         finish_reason = ""
@@ -1157,11 +1131,6 @@ class CustomStreamWrapper:
                     )
                 else:
                     completion_obj["content"] = str(chunk)
-            elif self.custom_llm_provider == "cohere":
-                response_obj = self.handle_cohere_chunk(chunk)
-                completion_obj["content"] = response_obj["text"]
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.received_finish_reason is not None:
@@ -1669,6 +1638,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "text-completion-codestral"
                 or self.custom_llm_provider == "azure_text"
                 or self.custom_llm_provider == "cohere_chat"
+                or self.custom_llm_provider == "cohere"
                 or self.custom_llm_provider == "anthropic"
                 or self.custom_llm_provider == "anthropic_text"
                 or self.custom_llm_provider == "huggingface"


@@ -8,13 +8,10 @@ import litellm
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.prompt_templates.factory import cohere_messages_pt_v2
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import (
-    ChatCompletionToolCallChunk,
-    ChatCompletionUsageBlock,
-    GenericStreamingChunk,
-    ModelResponse,
-    Usage,
-)
+from litellm.types.utils import ModelResponse, Usage
+
+from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
+from ..common_utils import validate_environment as cohere_validate_environment
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -120,28 +117,13 @@ class CohereChatConfig(BaseConfig):
         optional_params: dict,
         api_key: Optional[str] = None,
     ) -> dict:
-        """
-        Return headers to use for cohere chat completion request
-
-        Cohere API Ref: https://docs.cohere.com/reference/chat
-        Expected headers:
-        {
-            "Request-Source": "unspecified:litellm",
-            "accept": "application/json",
-            "content-type": "application/json",
-            "Authorization": "bearer $CO_API_KEY"
-        }
-        """
-        headers.update(
-            {
-                "Request-Source": "unspecified:litellm",
-                "accept": "application/json",
-                "content-type": "application/json",
-            }
-        )
-        if api_key:
-            headers["Authorization"] = f"bearer {api_key}"
-        return headers
+        return cohere_validate_environment(
+            headers=headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
 
     def get_supported_openai_params(self, model: str) -> List[str]:
         return [
@@ -372,7 +354,7 @@ class CohereChatConfig(BaseConfig):
         sync_stream: bool,
         json_mode: Optional[bool] = False,
     ):
-        return ModelResponseIterator(
+        return CohereModelResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
@@ -387,103 +369,3 @@ class CohereChatConfig(BaseConfig):
         self, messages: List[AllMessageValues]
     ) -> List[AllMessageValues]:
         raise NotImplementedError
-
-
-class ModelResponseIterator:
-    def __init__(
-        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
-    ):
-        self.streaming_response = streaming_response
-        self.response_iterator = self.streaming_response
-        self.content_blocks: List = []
-        self.tool_index = -1
-        self.json_mode = json_mode
-
-    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
-        try:
-            text = ""
-            tool_use: Optional[ChatCompletionToolCallChunk] = None
-            is_finished = False
-            finish_reason = ""
-            usage: Optional[ChatCompletionUsageBlock] = None
-            provider_specific_fields = None
-
-            index = int(chunk.get("index", 0))
-
-            if "text" in chunk:
-                text = chunk["text"]
-            elif "is_finished" in chunk and chunk["is_finished"] is True:
-                is_finished = chunk["is_finished"]
-                finish_reason = chunk["finish_reason"]
-
-            if "citations" in chunk:
-                provider_specific_fields = {"citations": chunk["citations"]}
-
-            returned_chunk = GenericStreamingChunk(
-                text=text,
-                tool_use=tool_use,
-                is_finished=is_finished,
-                finish_reason=finish_reason,
-                usage=usage,
-                index=index,
-                provider_specific_fields=provider_specific_fields,
-            )
-            return returned_chunk
-
-        except json.JSONDecodeError:
-            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
-
-    # Sync iterator
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        try:
-            chunk = self.response_iterator.__next__()
-        except StopIteration:
-            raise StopIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error receiving chunk from stream: {e}")
-
-        try:
-            str_line = chunk
-            if isinstance(chunk, bytes):  # Handle binary data
-                str_line = chunk.decode("utf-8")  # Convert bytes to string
-                index = str_line.find("data:")
-                if index != -1:
-                    str_line = str_line[index:]
-
-            data_json = json.loads(str_line)
-            return self.chunk_parser(chunk=data_json)
-        except StopIteration:
-            raise StopIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
-
-    # Async iterator
-    def __aiter__(self):
-        self.async_response_iterator = self.streaming_response.__aiter__()
-        return self
-
-    async def __anext__(self):
-        try:
-            chunk = await self.async_response_iterator.__anext__()
-        except StopAsyncIteration:
-            raise StopAsyncIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error receiving chunk from stream: {e}")
-
-        try:
-            str_line = chunk
-            if isinstance(chunk, bytes):  # Handle binary data
-                str_line = chunk.decode("utf-8")  # Convert bytes to string
-                index = str_line.find("data:")
-                if index != -1:
-                    str_line = str_line[index:]
-
-            data_json = json.loads(str_line)
-            return self.chunk_parser(chunk=data_json)
-        except StopAsyncIteration:
-            raise StopAsyncIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")


@@ -1,6 +1,13 @@
-from typing import Optional
+import json
+from typing import List, Optional
 
 from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+)
 
 
 class CohereError(BaseLLMException):
@@ -8,7 +15,25 @@ class CohereError(BaseLLMException):
         super().__init__(status_code=status_code, message=message)
 
 
-def validate_environment(*, api_key: Optional[str], headers: dict) -> dict:
+def validate_environment(
+    headers: dict,
+    model: str,
+    messages: List[AllMessageValues],
+    optional_params: dict,
+    api_key: Optional[str] = None,
+) -> dict:
+    """
+    Return headers to use for cohere chat completion request
+
+    Cohere API Ref: https://docs.cohere.com/reference/chat
+    Expected headers:
+    {
+        "Request-Source": "unspecified:litellm",
+        "accept": "application/json",
+        "content-type": "application/json",
+        "Authorization": "bearer $CO_API_KEY"
+    }
+    """
     headers.update(
         {
             "Request-Source": "unspecified:litellm",
@@ -17,5 +42,105 @@
         }
     )
     if api_key:
-        headers["Authorization"] = f"Bearer {api_key}"
+        headers["Authorization"] = f"bearer {api_key}"
     return headers
+
+
+class ModelResponseIterator:
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
+        self.streaming_response = streaming_response
+        self.response_iterator = self.streaming_response
+        self.content_blocks: List = []
+        self.tool_index = -1
+        self.json_mode = json_mode
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            index = int(chunk.get("index", 0))
+
+            if "text" in chunk:
+                text = chunk["text"]
+            elif "is_finished" in chunk and chunk["is_finished"] is True:
+                is_finished = chunk["is_finished"]
+                finish_reason = chunk["finish_reason"]
+
+            if "citations" in chunk:
+                provider_specific_fields = {"citations": chunk["citations"]}
+
+            returned_chunk = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+            return returned_chunk
+
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            chunk = self.response_iterator.__next__()
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")


@@ -1,155 +0,0 @@
-##### Calls /generate endpoint #######
-import json
-import os
-import time
-import traceback
-import types
-from enum import Enum
-from typing import Any, Callable, Optional, Union
-
-import httpx  # type: ignore
-import requests  # type: ignore
-
-import litellm
-from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.utils import Choices, Message, ModelResponse, Usage
-
-from ..common_utils import CohereError
-
-
-def construct_cohere_tool(tools=None):
-    if tools is None:
-        tools = []
-    return {"tools": tools}
-
-
-def validate_environment(api_key, headers: dict):
-    headers.update(
-        {
-            "Request-Source": "unspecified:litellm",
-            "accept": "application/json",
-            "content-type": "application/json",
-        }
-    )
-    if api_key:
-        headers["Authorization"] = f"Bearer {api_key}"
-    return headers
-
-
-def completion(
-    model: str,
-    messages: list,
-    api_base: str,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    headers: dict,
-    optional_params: dict,
-    litellm_params=None,
-    logger_fn=None,
-):
-    headers = validate_environment(api_key, headers=headers)
-    completion_url = api_base
-    model = model
-    prompt = " ".join(message["content"] for message in messages)
-
-    ## Load Config
-    config = litellm.CohereConfig.get_config()
-    for k, v in config.items():
-        if (
-            k not in optional_params
-        ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
-            optional_params[k] = v
-
-    ## Handle Tool Calling
-    if "tools" in optional_params:
-        _is_function_call = True
-        tool_calling_system_prompt = construct_cohere_tool(
-            tools=optional_params["tools"]
-        )
-        optional_params["tools"] = tool_calling_system_prompt
-
-    data = {
-        "model": model,
-        "prompt": prompt,
-        **optional_params,
-    }
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=prompt,
-        api_key=api_key,
-        additional_args={
-            "complete_input_dict": data,
-            "headers": headers,
-            "api_base": completion_url,
-        },
-    )
-
-    ## COMPLETION CALL
-    response = requests.post(
-        completion_url,
-        headers=headers,
-        data=json.dumps(data),
-        stream=optional_params["stream"] if "stream" in optional_params else False,
-    )
-    ## error handling for cohere calls
-    if response.status_code != 200:
-        raise CohereError(message=response.text, status_code=response.status_code)
-
-    if "stream" in optional_params and optional_params["stream"] is True:
-        return response.iter_lines()
-    else:
-        ## LOGGING
-        logging_obj.post_call(
-            input=prompt,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        completion_response = response.json()
-        if "error" in completion_response:
-            raise CohereError(
-                message=completion_response["error"],
-                status_code=response.status_code,
-            )
-        else:
-            try:
-                choices_list = []
-                for idx, item in enumerate(completion_response["generations"]):
-                    if len(item["text"]) > 0:
-                        message_obj = Message(content=item["text"])
-                    else:
-                        message_obj = Message(content=None)
-                    choice_obj = Choices(
-                        finish_reason=item["finish_reason"],
-                        index=idx + 1,
-                        message=message_obj,
-                    )
-                    choices_list.append(choice_obj)
-                model_response.choices = choices_list  # type: ignore
-            except Exception:
-                raise CohereError(
-                    message=response.text, status_code=response.status_code
-                )
-
-    ## CALCULATING USAGE
-    prompt_tokens = len(encoding.encode(prompt))
-    completion_tokens = len(
-        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-    )
-
-    model_response.created = int(time.time())
-    model_response.model = model
-    usage = Usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-    )
-    setattr(model_response, "usage", usage)
-    return model_response


@@ -1,13 +1,26 @@
+import json
+import time
 import types
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
 
 import httpx
 
+import litellm
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import ModelResponse
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    Choices,
+    GenericStreamingChunk,
+    Message,
+    ModelResponse,
+    Usage,
+)
 
 from ..common_utils import CohereError
+from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
 from ..common_utils import validate_environment as cohere_validate_environment
 
 if TYPE_CHECKING:
@@ -98,7 +111,13 @@ class CohereTextConfig(BaseConfig):
         optional_params: dict,
         api_key: Optional[str] = None,
     ) -> dict:
-        return cohere_validate_environment(api_key=api_key, headers=headers)
+        return cohere_validate_environment(
+            headers=headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
 
     def _transform_messages(
         self,
@@ -161,7 +180,33 @@ class CohereTextConfig(BaseConfig):
         litellm_params: dict,
         headers: dict,
     ) -> dict:
-        raise NotImplementedError
+        prompt = " ".join(
+            convert_content_list_to_str(message=message) for message in messages
+        )
+
+        ## Load Config
+        config = litellm.CohereConfig.get_config()
+        for k, v in config.items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        ## Handle Tool Calling
+        if "tools" in optional_params:
+            _is_function_call = True
+            tool_calling_system_prompt = self._construct_cohere_tool_for_completion_api(
+                tools=optional_params["tools"]
+            )
+            optional_params["tools"] = tool_calling_system_prompt
+
+        data = {
+            "model": model,
+            "prompt": prompt,
+            **optional_params,
+        }
+
+        return data
 
     def transform_response(
         self,
@@ -176,4 +221,56 @@ class CohereTextConfig(BaseConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        raise NotImplementedError
+        prompt = " ".join(
+            convert_content_list_to_str(message=message) for message in messages
+        )
+        completion_response = raw_response.json()
+
+        choices_list = []
+        for idx, item in enumerate(completion_response["generations"]):
+            if len(item["text"]) > 0:
+                message_obj = Message(content=item["text"])
+            else:
+                message_obj = Message(content=None)
+            choice_obj = Choices(
+                finish_reason=item["finish_reason"],
+                index=idx + 1,
+                message=message_obj,
+            )
+            choices_list.append(choice_obj)
+        model_response.choices = choices_list  # type: ignore
+
+        ## CALCULATING USAGE
+        prompt_tokens = len(encoding.encode(prompt))
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def _construct_cohere_tool_for_completion_api(
+        self,
+        tools: Optional[List] = None,
+    ) -> dict:
+        if tools is None:
+            tools = []
+        return {"tools": tools}
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return CohereModelResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )


@@ -373,8 +373,12 @@ class BaseLLMHTTPHandler:
                 error_headers = getattr(error_response, "headers", None)
                 if error_response and hasattr(error_response, "text"):
                     error_text = getattr(error_response, "text", error_text)
-                raise provider_config.error_class(  # type: ignore
-                    message=error_text,
+                if error_headers:
+                    error_headers = dict(error_headers)
+                else:
+                    error_headers = {}
+                raise provider_config.get_error_class(
+                    error_message=error_text,
                     status_code=status_code,
                     headers=error_headers,
                 )


@@ -109,7 +109,6 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
-from .llms.cohere.completion import completion as cohere_completion  # type: ignore
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
@@ -446,6 +445,7 @@ async def acompletion(
         or custom_llm_provider == "groq"
         or custom_llm_provider == "nvidia_nim"
         or custom_llm_provider == "cohere_chat"
+        or custom_llm_provider == "cohere"
         or custom_llm_provider == "cerebras"
         or custom_llm_provider == "sambanova"
         or custom_llm_provider == "ai21_chat"
@@ -1895,31 +1895,22 @@ def completion(  # type: ignore  # noqa: PLR0915
         if extra_headers is not None:
             headers.update(extra_headers)
-        model_response = cohere_completion.completion(
+        response = base_llm_http_handler.completion(
             model=model,
+            stream=stream,
             messages=messages,
+            acompletion=acompletion,
             api_base=api_base,
             model_response=model_response,
-            print_verbose=print_verbose,
             optional_params=optional_params,
             litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            encoding=encoding,
+            custom_llm_provider="cohere",
+            timeout=timeout,
             headers=headers,
+            encoding=encoding,
             api_key=cohere_key,
             logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
         )
-
-        if "stream" in optional_params and optional_params["stream"] is True:
-            # don't try to access stream object,
-            response = CustomStreamWrapper(
-                model_response,
-                model,
-                custom_llm_provider="cohere",
-                logging_obj=logging,
-            )
-            return response
-        response = model_response
     elif custom_llm_provider == "cohere_chat":
         cohere_key = (
             api_key

@@ -0,0 +1,184 @@
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import json
+
+import pytest
+
+import litellm
+from litellm import completion
+from litellm.llms.cohere.completion.transformation import CohereTextConfig
+
+
+@pytest.mark.asyncio
+async def test_cohere_generate_api_completion():
+    try:
+        litellm.set_verbose = False
+        messages = [
+            {"role": "system", "content": "You're a good bot"},
+            {
+                "role": "user",
+                "content": "Hey",
+            },
+        ]
+        response = completion(
+            model="cohere/command-nightly",
+            messages=messages,
+            max_tokens=10,
+        )
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_cohere_generate_api_stream():
+    try:
+        litellm.set_verbose = True
+        messages = [
+            {"role": "system", "content": "You're a good bot"},
+            {
+                "role": "user",
+                "content": "Hey",
+            },
+        ]
+        response = await litellm.acompletion(
+            model="cohere/command-nightly",
+            messages=messages,
+            max_tokens=10,
+            stream=True,
+        )
+        print("async cohere stream response", response)
+        async for chunk in response:
+            print(chunk)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_completion_cohere_stream_bad_key():
+    try:
+        api_key = "bad-key"
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        completion(
+            model="command-nightly",
+            messages=messages,
+            stream=True,
+            max_tokens=50,
+            api_key=api_key,
+        )
+    except litellm.AuthenticationError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_cohere_transform_request():
+    try:
+        config = CohereTextConfig()
+        messages = [
+            {"role": "system", "content": "You're a helpful bot"},
+            {"role": "user", "content": "Hello"},
+        ]
+        optional_params = {"max_tokens": 10, "temperature": 0.7}
+        headers = {}
+
+        transformed_request = config.transform_request(
+            model="command-nightly",
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params={},
+            headers=headers,
+        )
+
+        print("transformed_request", json.dumps(transformed_request, indent=4))
+
+        assert transformed_request["model"] == "command-nightly"
+        assert transformed_request["prompt"] == "You're a helpful bot Hello"
+        assert transformed_request["max_tokens"] == 10
+        assert transformed_request["temperature"] == 0.7
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_cohere_transform_request_with_tools():
+    try:
+        config = CohereTextConfig()
+        messages = [{"role": "user", "content": "What's the weather?"}]
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get weather information",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {"location": {"type": "string"}},
+                    },
+                },
+            }
+        ]
+        optional_params = {"tools": tools}
+
+        transformed_request = config.transform_request(
+            model="command-nightly",
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params={},
+            headers={},
+        )
+
+        print("transformed_request", json.dumps(transformed_request, indent=4))
+        assert "tools" in transformed_request
+        assert transformed_request["tools"] == {"tools": tools}
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_cohere_map_openai_params():
+    try:
+        config = CohereTextConfig()
+        openai_params = {
+            "temperature": 0.7,
+            "max_tokens": 100,
+            "n": 2,
+            "top_p": 0.9,
+            "frequency_penalty": 0.5,
+            "presence_penalty": 0.5,
+            "stop": ["END"],
+            "stream": True,
+        }
+
+        mapped_params = config.map_openai_params(
+            non_default_params=openai_params,
+            optional_params={},
+            model="command-nightly",
+            drop_params=False,
+        )
+
+        assert mapped_params["temperature"] == 0.7
+        assert mapped_params["max_tokens"] == 100
+        assert mapped_params["num_generations"] == 2
+        assert mapped_params["p"] == 0.9
+        assert mapped_params["frequency_penalty"] == 0.5
+        assert mapped_params["presence_penalty"] == 0.5
+        assert mapped_params["stop_sequences"] == ["END"]
+        assert mapped_params["stream"] == True
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")


@@ -436,47 +436,6 @@ def test_completion_azure_stream_content_filter_no_delta():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
-def test_completion_cohere_stream_bad_key():
-    try:
-        litellm.cache = None
-        api_key = "bad-key"
-        messages = [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {
-                "role": "user",
-                "content": "how does a court case get to the Supreme Court?",
-            },
-        ]
-        response = completion(
-            model="command-nightly",
-            messages=messages,
-            stream=True,
-            max_tokens=50,
-            api_key=api_key,
-        )
-        complete_response = ""
-        # Add any assertions here to check the response
-        has_finish_reason = False
-        for idx, chunk in enumerate(response):
-            chunk, finished = streaming_format_tests(idx, chunk)
-            has_finish_reason = finished
-            if finished:
-                break
-            complete_response += chunk
-        if has_finish_reason is False:
-            raise Exception("Finish reason not in final chunk")
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
-        print(f"completion_response: {complete_response}")
-    except AuthenticationError as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
-# test_completion_cohere_stream_bad_key()
-
-
 @pytest.mark.flaky(retries=5, delay=1)
 def test_completion_azure_stream():
     try: