Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
(Refactor) Code Quality improvement - Use Common base handler for Cohere /generate API (#7122)
* use validate_environment in common utils
* use transform request / response for cohere
* remove unused file
* use cohere base_llm_http_handler
* working cohere generate api on llm http handler
* streaming cohere generate api
* fix get_model_response_iterator
* fix streaming handler
* fix get_model_response_iterator
* test_cohere_generate_api_completion
* fix linting error
* fix testing cohere raising error
* fix get_model_response_iterator type
* add testing cohere generate api
This commit is contained in:
parent bd39e1ab5d
commit 5e016fe66a

9 changed files with 439 additions and 382 deletions
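A minimal usage sketch of the refactored path (assuming a valid COHERE_API_KEY in the environment), mirroring the tests added in this commit: the Cohere /generate route now flows through the shared base_llm_http_handler.

import litellm

# "cohere/<model>" routes to the /generate endpoint; after this refactor the
# request/response shaping lives in CohereTextConfig and the HTTP call in
# BaseLLMHTTPHandler.
response = litellm.completion(
    model="cohere/command-nightly",
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=10,
)
print(response)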
@@ -411,32 +411,6 @@ class CustomStreamWrapper:
        except Exception:
            raise ValueError(f"Unable to parse response. Original response: {chunk}")

    def handle_cohere_chunk(self, chunk):
        chunk = chunk.decode("utf-8")
        data_json = json.loads(chunk)
        try:
            text = ""
            is_finished = False
            finish_reason = ""
            index: Optional[int] = None
            if "index" in data_json:
                index = data_json.get("index")
            if "text" in data_json:
                text = data_json["text"]
            elif "is_finished" in data_json:
                is_finished = data_json["is_finished"]
                finish_reason = data_json["finish_reason"]
            else:
                raise Exception(data_json)
            return {
                "index": index,
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
            }
        except Exception:
            raise ValueError(f"Unable to parse response. Original response: {chunk}")

    def handle_azure_chunk(self, chunk):
        is_finished = False
        finish_reason = ""
@@ -1157,11 +1131,6 @@ class CustomStreamWrapper:
                    )
                else:
                    completion_obj["content"] = str(chunk)
            elif self.custom_llm_provider == "cohere":
                response_obj = self.handle_cohere_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "petals":
                if len(self.completion_stream) == 0:
                    if self.received_finish_reason is not None:
@@ -1669,6 +1638,7 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "text-completion-codestral"
                or self.custom_llm_provider == "azure_text"
                or self.custom_llm_provider == "cohere_chat"
                or self.custom_llm_provider == "cohere"
                or self.custom_llm_provider == "anthropic"
                or self.custom_llm_provider == "anthropic_text"
                or self.custom_llm_provider == "huggingface"
@@ -8,13 +8,10 @@ import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.prompt_templates.factory import cohere_messages_pt_v2
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    GenericStreamingChunk,
    ModelResponse,
    Usage,
)
from litellm.types.utils import ModelResponse, Usage

from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
from ..common_utils import validate_environment as cohere_validate_environment

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -120,28 +117,13 @@ class CohereChatConfig(BaseConfig):
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Return headers to use for cohere chat completion request

        Cohere API Ref: https://docs.cohere.com/reference/chat
        Expected headers:
        {
            "Request-Source": "unspecified:litellm",
            "accept": "application/json",
            "content-type": "application/json",
            "Authorization": "bearer $CO_API_KEY"
        }
        """
        headers.update(
            {
                "Request-Source": "unspecified:litellm",
                "accept": "application/json",
                "content-type": "application/json",
            }
        return cohere_validate_environment(
            headers=headers,
            model=model,
            messages=messages,
            optional_params=optional_params,
            api_key=api_key,
        )
        if api_key:
            headers["Authorization"] = f"bearer {api_key}"
        return headers

    def get_supported_openai_params(self, model: str) -> List[str]:
        return [
@@ -372,7 +354,7 @@ class CohereChatConfig(BaseConfig):
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return ModelResponseIterator(
        return CohereModelResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
@@ -387,103 +369,3 @@ class CohereChatConfig(BaseConfig):
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        raise NotImplementedError


class ModelResponseIterator:
    def __init__(
        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
    ):
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.content_blocks: List = []
        self.tool_index = -1
        self.json_mode = json_mode

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None

            index = int(chunk.get("index", 0))

            if "text" in chunk:
                text = chunk["text"]
            elif "is_finished" in chunk and chunk["is_finished"] is True:
                is_finished = chunk["is_finished"]
                finish_reason = chunk["finish_reason"]

            if "citations" in chunk:
                provider_specific_fields = {"citations": chunk["citations"]}

            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )

            return returned_chunk

        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = self.response_iterator.__next__()
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            str_line = chunk
            if isinstance(chunk, bytes):  # Handle binary data
                str_line = chunk.decode("utf-8")  # Convert bytes to string
            index = str_line.find("data:")
            if index != -1:
                str_line = str_line[index:]
            data_json = json.loads(str_line)
            return self.chunk_parser(chunk=data_json)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        try:
            chunk = await self.async_response_iterator.__anext__()
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            str_line = chunk
            if isinstance(chunk, bytes):  # Handle binary data
                str_line = chunk.decode("utf-8")  # Convert bytes to string
            index = str_line.find("data:")
            if index != -1:
                str_line = str_line[index:]

            data_json = json.loads(str_line)
            return self.chunk_parser(chunk=data_json)
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
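For orientation, a sketch of what the relocated chunk_parser (now living in cohere/common_utils.py, see below) yields for typical /generate stream chunks; the sample payloads are illustrative, not captured from a live stream.

# Illustrative stream chunks in the dict shape chunk_parser expects.
text_chunk = {"index": 0, "text": "Hello"}
final_chunk = {"index": 0, "is_finished": True, "finish_reason": "COMPLETE"}

# chunk_parser(text_chunk)  -> GenericStreamingChunk(text="Hello", is_finished=False, ...)
# chunk_parser(final_chunk) -> GenericStreamingChunk(text="", is_finished=True,
#                                                    finish_reason="COMPLETE", ...)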
@@ -1,6 +1,13 @@
from typing import Optional
import json
from typing import List, Optional

from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    GenericStreamingChunk,
)


class CohereError(BaseLLMException):
@@ -8,7 +15,25 @@ class CohereError(BaseLLMException):
        super().__init__(status_code=status_code, message=message)


def validate_environment(*, api_key: Optional[str], headers: dict) -> dict:
def validate_environment(
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    api_key: Optional[str] = None,
) -> dict:
    """
    Return headers to use for cohere chat completion request

    Cohere API Ref: https://docs.cohere.com/reference/chat
    Expected headers:
    {
        "Request-Source": "unspecified:litellm",
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": "bearer $CO_API_KEY"
    }
    """
    headers.update(
        {
            "Request-Source": "unspecified:litellm",
@@ -17,5 +42,105 @@ def validate_environment(*, api_key: Optional[str], headers: dict) -> dict:
        }
    )
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
        headers["Authorization"] = f"bearer {api_key}"
    return headers


class ModelResponseIterator:
    def __init__(
        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
    ):
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.content_blocks: List = []
        self.tool_index = -1
        self.json_mode = json_mode

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None

            index = int(chunk.get("index", 0))

            if "text" in chunk:
                text = chunk["text"]
            elif "is_finished" in chunk and chunk["is_finished"] is True:
                is_finished = chunk["is_finished"]
                finish_reason = chunk["finish_reason"]

            if "citations" in chunk:
                provider_specific_fields = {"citations": chunk["citations"]}

            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )

            return returned_chunk

        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = self.response_iterator.__next__()
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            str_line = chunk
            if isinstance(chunk, bytes):  # Handle binary data
                str_line = chunk.decode("utf-8")  # Convert bytes to string
            index = str_line.find("data:")
            if index != -1:
                str_line = str_line[index:]
            data_json = json.loads(str_line)
            return self.chunk_parser(chunk=data_json)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        try:
            chunk = await self.async_response_iterator.__anext__()
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            str_line = chunk
            if isinstance(chunk, bytes):  # Handle binary data
                str_line = chunk.decode("utf-8")  # Convert bytes to string
            index = str_line.find("data:")
            if index != -1:
                str_line = str_line[index:]

            data_json = json.loads(str_line)
            return self.chunk_parser(chunk=data_json)
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
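As a quick illustration, the shared validate_environment above produces headers of roughly this shape for both the chat and /generate configs (placeholder key shown):

headers = validate_environment(
    headers={},
    model="command-nightly",
    messages=[{"role": "user", "content": "Hey"}],
    optional_params={},
    api_key="CO_API_KEY_PLACEHOLDER",  # placeholder value, not a real key
)
# -> {
#      "Request-Source": "unspecified:litellm",
#      "accept": "application/json",
#      "content-type": "application/json",
#      "Authorization": "bearer CO_API_KEY_PLACEHOLDER",
#    }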
@@ -1,155 +0,0 @@
##### Calls /generate endpoint #######

import json
import os
import time
import traceback
import types
from enum import Enum
from typing import Any, Callable, Optional, Union

import httpx  # type: ignore
import requests  # type: ignore

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import Choices, Message, ModelResponse, Usage

from ..common_utils import CohereError


def construct_cohere_tool(tools=None):
    if tools is None:
        tools = []
    return {"tools": tools}


def validate_environment(api_key, headers: dict):
    headers.update(
        {
            "Request-Source": "unspecified:litellm",
            "accept": "application/json",
            "content-type": "application/json",
        }
    )
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    return headers


def completion(
    model: str,
    messages: list,
    api_base: str,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    headers: dict,
    optional_params: dict,
    litellm_params=None,
    logger_fn=None,
):
    headers = validate_environment(api_key, headers=headers)
    completion_url = api_base
    model = model
    prompt = " ".join(message["content"] for message in messages)

    ## Load Config
    config = litellm.CohereConfig.get_config()
    for k, v in config.items():
        if (
            k not in optional_params
        ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
            optional_params[k] = v

    ## Handle Tool Calling
    if "tools" in optional_params:
        _is_function_call = True
        tool_calling_system_prompt = construct_cohere_tool(
            tools=optional_params["tools"]
        )
        optional_params["tools"] = tool_calling_system_prompt

    data = {
        "model": model,
        "prompt": prompt,
        **optional_params,
    }

    ## LOGGING
    logging_obj.pre_call(
        input=prompt,
        api_key=api_key,
        additional_args={
            "complete_input_dict": data,
            "headers": headers,
            "api_base": completion_url,
        },
    )
    ## COMPLETION CALL
    response = requests.post(
        completion_url,
        headers=headers,
        data=json.dumps(data),
        stream=optional_params["stream"] if "stream" in optional_params else False,
    )
    ## error handling for cohere calls
    if response.status_code != 200:
        raise CohereError(message=response.text, status_code=response.status_code)

    if "stream" in optional_params and optional_params["stream"] is True:
        return response.iter_lines()
    else:
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"raw model_response: {response.text}")
        ## RESPONSE OBJECT
        completion_response = response.json()
        if "error" in completion_response:
            raise CohereError(
                message=completion_response["error"],
                status_code=response.status_code,
            )
        else:
            try:
                choices_list = []
                for idx, item in enumerate(completion_response["generations"]):
                    if len(item["text"]) > 0:
                        message_obj = Message(content=item["text"])
                    else:
                        message_obj = Message(content=None)
                    choice_obj = Choices(
                        finish_reason=item["finish_reason"],
                        index=idx + 1,
                        message=message_obj,
                    )
                    choices_list.append(choice_obj)
                model_response.choices = choices_list  # type: ignore
            except Exception:
                raise CohereError(
                    message=response.text, status_code=response.status_code
                )

        ## CALCULATING USAGE
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response
@@ -1,13 +1,26 @@
import json
import time
import types
from typing import TYPE_CHECKING, Any, List, Optional, Union
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union

import httpx

import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from litellm.types.utils import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    Choices,
    GenericStreamingChunk,
    Message,
    ModelResponse,
    Usage,
)

from ..common_utils import CohereError
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
from ..common_utils import validate_environment as cohere_validate_environment

if TYPE_CHECKING:
@@ -98,7 +111,13 @@ class CohereTextConfig(BaseConfig):
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        return cohere_validate_environment(api_key=api_key, headers=headers)
        return cohere_validate_environment(
            headers=headers,
            model=model,
            messages=messages,
            optional_params=optional_params,
            api_key=api_key,
        )

    def _transform_messages(
        self,
@@ -161,7 +180,33 @@ class CohereTextConfig(BaseConfig):
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        raise NotImplementedError
        prompt = " ".join(
            convert_content_list_to_str(message=message) for message in messages
        )

        ## Load Config
        config = litellm.CohereConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        ## Handle Tool Calling
        if "tools" in optional_params:
            _is_function_call = True
            tool_calling_system_prompt = self._construct_cohere_tool_for_completion_api(
                tools=optional_params["tools"]
            )
            optional_params["tools"] = tool_calling_system_prompt

        data = {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }

        return data

    def transform_response(
        self,
@@ -176,4 +221,56 @@ class CohereTextConfig(BaseConfig):
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError
        prompt = " ".join(
            convert_content_list_to_str(message=message) for message in messages
        )
        completion_response = raw_response.json()
        choices_list = []
        for idx, item in enumerate(completion_response["generations"]):
            if len(item["text"]) > 0:
                message_obj = Message(content=item["text"])
            else:
                message_obj = Message(content=None)
            choice_obj = Choices(
                finish_reason=item["finish_reason"],
                index=idx + 1,
                message=message_obj,
            )
            choices_list.append(choice_obj)
        model_response.choices = choices_list  # type: ignore

        ## CALCULATING USAGE
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

    def _construct_cohere_tool_for_completion_api(
        self,
        tools: Optional[List] = None,
    ) -> dict:
        if tools is None:
            tools = []
        return {"tools": tools}

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return CohereModelResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
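For reference, a sketch of the /generate payload this transform_request builds; the values mirror test_cohere_transform_request in the new test file below.

config = CohereTextConfig()
request = config.transform_request(
    model="command-nightly",
    messages=[
        {"role": "system", "content": "You're a helpful bot"},
        {"role": "user", "content": "Hello"},
    ],
    optional_params={"max_tokens": 10, "temperature": 0.7},
    litellm_params={},
    headers={},
)
# request == {
#     "model": "command-nightly",
#     "prompt": "You're a helpful bot Hello",
#     "max_tokens": 10,
#     "temperature": 0.7,
# }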
@@ -373,8 +373,12 @@ class BaseLLMHTTPHandler:
        error_headers = getattr(error_response, "headers", None)
        if error_response and hasattr(error_response, "text"):
            error_text = getattr(error_response, "text", error_text)
        raise provider_config.error_class(  # type: ignore
            message=error_text,
        if error_headers:
            error_headers = dict(error_headers)
        else:
            error_headers = {}
        raise provider_config.get_error_class(
            error_message=error_text,
            status_code=status_code,
            headers=error_headers,
        )
@@ -109,7 +109,6 @@ from .llms.azure_text import AzureTextCompletion
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.bedrock.image.image_handler import BedrockImageGeneration
from .llms.cohere.completion import completion as cohere_completion  # type: ignore
from .llms.cohere.embed import handler as cohere_embed
from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
@@ -446,6 +445,7 @@
        or custom_llm_provider == "groq"
        or custom_llm_provider == "nvidia_nim"
        or custom_llm_provider == "cohere_chat"
        or custom_llm_provider == "cohere"
        or custom_llm_provider == "cerebras"
        or custom_llm_provider == "sambanova"
        or custom_llm_provider == "ai21_chat"
@@ -1895,31 +1895,22 @@ def completion(  # type: ignore # noqa: PLR0915
            if extra_headers is not None:
                headers.update(extra_headers)

            model_response = cohere_completion.completion(
            response = base_llm_http_handler.completion(
                model=model,
                stream=stream,
                messages=messages,
                acompletion=acompletion,
                api_base=api_base,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                encoding=encoding,
                custom_llm_provider="cohere",
                timeout=timeout,
                headers=headers,
                encoding=encoding,
                api_key=cohere_key,
                logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
            )

            if "stream" in optional_params and optional_params["stream"] is True:
                # don't try to access stream object,
                response = CustomStreamWrapper(
                    model_response,
                    model,
                    custom_llm_provider="cohere",
                    logging_obj=logging,
                )
                return response
            response = model_response
        elif custom_llm_provider == "cohere_chat":
            cohere_key = (
                api_key
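With the handler swap above, streaming for the cohere /generate route is also driven by the shared HTTP handler. A minimal consumption sketch, mirroring test_cohere_generate_api_stream and assuming COHERE_API_KEY is set:

import asyncio

import litellm

async def main():
    stream = await litellm.acompletion(
        model="cohere/command-nightly",
        messages=[{"role": "user", "content": "Hey"}],
        max_tokens=10,
        stream=True,
    )
    # Chunks are assembled from GenericStreamingChunk objects produced by the
    # shared ModelResponseIterator in cohere/common_utils.py.
    async for chunk in stream:
        print(chunk)

asyncio.run(main())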
tests/llm_translation/test_cohere_generate_api.py (new file, 184 lines)
@@ -0,0 +1,184 @@
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()
import io
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import json

import pytest

import litellm
from litellm import completion
from litellm.llms.cohere.completion.transformation import CohereTextConfig


@pytest.mark.asyncio
async def test_cohere_generate_api_completion():
    try:
        litellm.set_verbose = False
        messages = [
            {"role": "system", "content": "You're a good bot"},
            {
                "role": "user",
                "content": "Hey",
            },
        ]
        response = completion(
            model="cohere/command-nightly",
            messages=messages,
            max_tokens=10,
        )
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio
async def test_cohere_generate_api_stream():
    try:
        litellm.set_verbose = True
        messages = [
            {"role": "system", "content": "You're a good bot"},
            {
                "role": "user",
                "content": "Hey",
            },
        ]
        response = await litellm.acompletion(
            model="cohere/command-nightly",
            messages=messages,
            max_tokens=10,
            stream=True,
        )
        print("async cohere stream response", response)
        async for chunk in response:
            print(chunk)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_completion_cohere_stream_bad_key():
    try:
        api_key = "bad-key"
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": "how does a court case get to the Supreme Court?",
            },
        ]
        completion(
            model="command-nightly",
            messages=messages,
            stream=True,
            max_tokens=50,
            api_key=api_key,
        )

    except litellm.AuthenticationError as e:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_cohere_transform_request():
    try:
        config = CohereTextConfig()
        messages = [
            {"role": "system", "content": "You're a helpful bot"},
            {"role": "user", "content": "Hello"},
        ]
        optional_params = {"max_tokens": 10, "temperature": 0.7}
        headers = {}

        transformed_request = config.transform_request(
            model="command-nightly",
            messages=messages,
            optional_params=optional_params,
            litellm_params={},
            headers=headers,
        )

        print("transformed_request", json.dumps(transformed_request, indent=4))

        assert transformed_request["model"] == "command-nightly"
        assert transformed_request["prompt"] == "You're a helpful bot Hello"
        assert transformed_request["max_tokens"] == 10
        assert transformed_request["temperature"] == 0.7
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_cohere_transform_request_with_tools():
    try:
        config = CohereTextConfig()
        messages = [{"role": "user", "content": "What's the weather?"}]
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get weather information",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                    },
                },
            }
        ]
        optional_params = {"tools": tools}

        transformed_request = config.transform_request(
            model="command-nightly",
            messages=messages,
            optional_params=optional_params,
            litellm_params={},
            headers={},
        )

        print("transformed_request", json.dumps(transformed_request, indent=4))
        assert "tools" in transformed_request
        assert transformed_request["tools"] == {"tools": tools}
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_cohere_map_openai_params():
    try:
        config = CohereTextConfig()
        openai_params = {
            "temperature": 0.7,
            "max_tokens": 100,
            "n": 2,
            "top_p": 0.9,
            "frequency_penalty": 0.5,
            "presence_penalty": 0.5,
            "stop": ["END"],
            "stream": True,
        }

        mapped_params = config.map_openai_params(
            non_default_params=openai_params,
            optional_params={},
            model="command-nightly",
            drop_params=False,
        )

        assert mapped_params["temperature"] == 0.7
        assert mapped_params["max_tokens"] == 100
        assert mapped_params["num_generations"] == 2
        assert mapped_params["p"] == 0.9
        assert mapped_params["frequency_penalty"] == 0.5
        assert mapped_params["presence_penalty"] == 0.5
        assert mapped_params["stop_sequences"] == ["END"]
        assert mapped_params["stream"] == True
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@@ -436,47 +436,6 @@ def test_completion_azure_stream_content_filter_no_delta():
        pytest.fail(f"An exception occurred - {str(e)}")


def test_completion_cohere_stream_bad_key():
    try:
        litellm.cache = None
        api_key = "bad-key"
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": "how does a court case get to the Supreme Court?",
            },
        ]
        response = completion(
            model="command-nightly",
            messages=messages,
            stream=True,
            max_tokens=50,
            api_key=api_key,
        )
        complete_response = ""
        # Add any assertions here to check the response
        has_finish_reason = False
        for idx, chunk in enumerate(response):
            chunk, finished = streaming_format_tests(idx, chunk)
            has_finish_reason = finished
            if finished:
                break
            complete_response += chunk
        if has_finish_reason is False:
            raise Exception("Finish reason not in final chunk")
        if complete_response.strip() == "":
            raise Exception("Empty response received")
        print(f"completion_response: {complete_response}")
    except AuthenticationError as e:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_completion_cohere_stream_bad_key()


@pytest.mark.flaky(retries=5, delay=1)
def test_completion_azure_stream():
    try: