mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Refactor) Code Quality improvement - Use Common base handler for Cohere /generate API (#7122)
* use validate_environment in common utils
* use transform request / response for cohere
* remove unused file
* use cohere base_llm_http_handler
* working cohere generate api on llm http handler
* streaming cohere generate api
* fix get_model_response_iterator
* fix streaming handler
* fix get_model_response_iterator
* test_cohere_generate_api_completion
* fix linting error
* fix testing cohere raising error
* fix get_model_response_iterator type
* add testing cohere generate api
This commit is contained in:
parent
9c2316b7ec
commit
1b377d5229
9 changed files with 439 additions and 382 deletions
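
After this refactor, Cohere /generate calls run through the shared base_llm_http_handler end to end instead of the bespoke cohere handler. A minimal usage sketch, based on the model and message shape used in the new tests (this assumes COHERE_API_KEY is set in the environment; it is not a definitive example from the PR itself):

    import litellm

    # Non-streaming: the "cohere" provider is now served by base_llm_http_handler
    response = litellm.completion(
        model="cohere/command-nightly",
        messages=[{"role": "user", "content": "Hey"}],
        max_tokens=10,
    )
    print(response.choices[0].message.content)

    # Streaming goes through the same handler and the shared ModelResponseIterator
    stream = litellm.completion(
        model="cohere/command-nightly",
        messages=[{"role": "user", "content": "Hey"}],
        max_tokens=10,
        stream=True,
    )
    for chunk in stream:
        print(chunk)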

@@ -411,32 +411,6 @@ class CustomStreamWrapper:
         except Exception:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")

-    def handle_cohere_chunk(self, chunk):
-        chunk = chunk.decode("utf-8")
-        data_json = json.loads(chunk)
-        try:
-            text = ""
-            is_finished = False
-            finish_reason = ""
-            index: Optional[int] = None
-            if "index" in data_json:
-                index = data_json.get("index")
-            if "text" in data_json:
-                text = data_json["text"]
-            elif "is_finished" in data_json:
-                is_finished = data_json["is_finished"]
-                finish_reason = data_json["finish_reason"]
-            else:
-                raise Exception(data_json)
-            return {
-                "index": index,
-                "text": text,
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
-        except Exception:
-            raise ValueError(f"Unable to parse response. Original response: {chunk}")
-
     def handle_azure_chunk(self, chunk):
         is_finished = False
         finish_reason = ""

@@ -1157,11 +1131,6 @@ class CustomStreamWrapper:
                     )
                 else:
                     completion_obj["content"] = str(chunk)
-            elif self.custom_llm_provider == "cohere":
-                response_obj = self.handle_cohere_chunk(chunk)
-                completion_obj["content"] = response_obj["text"]
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.received_finish_reason is not None:

@@ -1669,6 +1638,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "text-completion-codestral"
                 or self.custom_llm_provider == "azure_text"
                 or self.custom_llm_provider == "cohere_chat"
+                or self.custom_llm_provider == "cohere"
                 or self.custom_llm_provider == "anthropic"
                 or self.custom_llm_provider == "anthropic_text"
                 or self.custom_llm_provider == "huggingface"
@@ -8,13 +8,10 @@ import litellm
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.prompt_templates.factory import cohere_messages_pt_v2
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import (
-    ChatCompletionToolCallChunk,
-    ChatCompletionUsageBlock,
-    GenericStreamingChunk,
-    ModelResponse,
-    Usage,
-)
+from litellm.types.utils import ModelResponse, Usage
+
+from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
+from ..common_utils import validate_environment as cohere_validate_environment

 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

@@ -120,28 +117,13 @@ class CohereChatConfig(BaseConfig):
         optional_params: dict,
         api_key: Optional[str] = None,
     ) -> dict:
-        """
-        Return headers to use for cohere chat completion request
-
-        Cohere API Ref: https://docs.cohere.com/reference/chat
-        Expected headers:
-        {
-            "Request-Source": "unspecified:litellm",
-            "accept": "application/json",
-            "content-type": "application/json",
-            "Authorization": "bearer $CO_API_KEY"
-        }
-        """
-        headers.update(
-            {
-                "Request-Source": "unspecified:litellm",
-                "accept": "application/json",
-                "content-type": "application/json",
-            }
-        )
-        if api_key:
-            headers["Authorization"] = f"bearer {api_key}"
-        return headers
+        return cohere_validate_environment(
+            headers=headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )

     def get_supported_openai_params(self, model: str) -> List[str]:
         return [

@@ -372,7 +354,7 @@ class CohereChatConfig(BaseConfig):
         sync_stream: bool,
         json_mode: Optional[bool] = False,
     ):
-        return ModelResponseIterator(
+        return CohereModelResponseIterator(
             streaming_response=streaming_response,
             sync_stream=sync_stream,
             json_mode=json_mode,
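
Both the chat and /generate configs now delegate header construction to the shared validate_environment added to common_utils further down. A rough sketch of how it is called and what it returns, assuming the relative ..common_utils import resolves to litellm.llms.cohere.common_utils and using a placeholder key:

    from litellm.llms.cohere.common_utils import validate_environment

    headers = validate_environment(
        headers={},
        model="command-nightly",
        messages=[{"role": "user", "content": "Hey"}],
        optional_params={},
        api_key="<CO_API_KEY>",  # placeholder, not a real key
    )
    # Expected result, per the docstring shown in the diff:
    # {
    #     "Request-Source": "unspecified:litellm",
    #     "accept": "application/json",
    #     "content-type": "application/json",
    #     "Authorization": "bearer <CO_API_KEY>",
    # }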
@@ -387,103 +369,3 @@ class CohereChatConfig(BaseConfig):
         self, messages: List[AllMessageValues]
     ) -> List[AllMessageValues]:
         raise NotImplementedError
-
-
-class ModelResponseIterator:
-    def __init__(
-        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
-    ):
-        self.streaming_response = streaming_response
-        self.response_iterator = self.streaming_response
-        self.content_blocks: List = []
-        self.tool_index = -1
-        self.json_mode = json_mode
-
-    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
-        try:
-            text = ""
-            tool_use: Optional[ChatCompletionToolCallChunk] = None
-            is_finished = False
-            finish_reason = ""
-            usage: Optional[ChatCompletionUsageBlock] = None
-            provider_specific_fields = None
-
-            index = int(chunk.get("index", 0))
-
-            if "text" in chunk:
-                text = chunk["text"]
-            elif "is_finished" in chunk and chunk["is_finished"] is True:
-                is_finished = chunk["is_finished"]
-                finish_reason = chunk["finish_reason"]
-
-            if "citations" in chunk:
-                provider_specific_fields = {"citations": chunk["citations"]}
-
-            returned_chunk = GenericStreamingChunk(
-                text=text,
-                tool_use=tool_use,
-                is_finished=is_finished,
-                finish_reason=finish_reason,
-                usage=usage,
-                index=index,
-                provider_specific_fields=provider_specific_fields,
-            )
-
-            return returned_chunk
-
-        except json.JSONDecodeError:
-            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
-
-    # Sync iterator
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        try:
-            chunk = self.response_iterator.__next__()
-        except StopIteration:
-            raise StopIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error receiving chunk from stream: {e}")
-
-        try:
-            str_line = chunk
-            if isinstance(chunk, bytes):  # Handle binary data
-                str_line = chunk.decode("utf-8")  # Convert bytes to string
-                index = str_line.find("data:")
-                if index != -1:
-                    str_line = str_line[index:]
-            data_json = json.loads(str_line)
-            return self.chunk_parser(chunk=data_json)
-        except StopIteration:
-            raise StopIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
-
-    # Async iterator
-    def __aiter__(self):
-        self.async_response_iterator = self.streaming_response.__aiter__()
-        return self
-
-    async def __anext__(self):
-        try:
-            chunk = await self.async_response_iterator.__anext__()
-        except StopAsyncIteration:
-            raise StopAsyncIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error receiving chunk from stream: {e}")
-
-        try:
-            str_line = chunk
-            if isinstance(chunk, bytes):  # Handle binary data
-                str_line = chunk.decode("utf-8")  # Convert bytes to string
-                index = str_line.find("data:")
-                if index != -1:
-                    str_line = str_line[index:]
-
-            data_json = json.loads(str_line)
-            return self.chunk_parser(chunk=data_json)
-        except StopAsyncIteration:
-            raise StopAsyncIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
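
The ModelResponseIterator removed here is not dropped: an identical implementation is added to the shared common_utils module in the next hunk, and both the chat and /generate configs now import it from there under an alias, as shown in the import changes above:

    from ..common_utils import ModelResponseIterator as CohereModelResponseIterator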
@@ -1,6 +1,13 @@
-from typing import Optional
+import json
+from typing import List, Optional

 from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+)


 class CohereError(BaseLLMException):

@@ -8,7 +15,25 @@ class CohereError(BaseLLMException):
         super().__init__(status_code=status_code, message=message)


-def validate_environment(*, api_key: Optional[str], headers: dict) -> dict:
+def validate_environment(
+    headers: dict,
+    model: str,
+    messages: List[AllMessageValues],
+    optional_params: dict,
+    api_key: Optional[str] = None,
+) -> dict:
+    """
+    Return headers to use for cohere chat completion request
+
+    Cohere API Ref: https://docs.cohere.com/reference/chat
+    Expected headers:
+    {
+        "Request-Source": "unspecified:litellm",
+        "accept": "application/json",
+        "content-type": "application/json",
+        "Authorization": "bearer $CO_API_KEY"
+    }
+    """
     headers.update(
         {
             "Request-Source": "unspecified:litellm",

@@ -17,5 +42,105 @@ def validate_environment(*, api_key: Optional[str], headers: dict) -> dict:
         }
     )
     if api_key:
-        headers["Authorization"] = f"Bearer {api_key}"
+        headers["Authorization"] = f"bearer {api_key}"
     return headers
+
+
+class ModelResponseIterator:
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
+        self.streaming_response = streaming_response
+        self.response_iterator = self.streaming_response
+        self.content_blocks: List = []
+        self.tool_index = -1
+        self.json_mode = json_mode
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            index = int(chunk.get("index", 0))
+
+            if "text" in chunk:
+                text = chunk["text"]
+            elif "is_finished" in chunk and chunk["is_finished"] is True:
+                is_finished = chunk["is_finished"]
+                finish_reason = chunk["finish_reason"]
+
+            if "citations" in chunk:
+                provider_specific_fields = {"citations": chunk["citations"]}
+
+            returned_chunk = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+
+            return returned_chunk
+
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            chunk = self.response_iterator.__next__()
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+                index = str_line.find("data:")
+                if index != -1:
+                    str_line = str_line[index:]
+
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
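
A small illustrative sketch of the iterator's chunk_parser, showing how decoded Cohere stream chunks map onto GenericStreamingChunk (the chunk dicts and the "COMPLETE" finish reason are illustrative values, not taken from a live response):

    # Illustrative only: feed two already-decoded chunk dicts through chunk_parser
    iterator = ModelResponseIterator(streaming_response=iter([]), sync_stream=True)

    text_chunk = iterator.chunk_parser({"index": 0, "text": "Hello"})
    # -> GenericStreamingChunk(text="Hello", is_finished=False, finish_reason="", index=0, ...)

    final_chunk = iterator.chunk_parser({"is_finished": True, "finish_reason": "COMPLETE"})
    # -> GenericStreamingChunk(text="", is_finished=True, finish_reason="COMPLETE", index=0, ...)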
@@ -1,155 +0,0 @@
-##### Calls /generate endpoint #######
-
-import json
-import os
-import time
-import traceback
-import types
-from enum import Enum
-from typing import Any, Callable, Optional, Union
-
-import httpx  # type: ignore
-import requests  # type: ignore
-
-import litellm
-from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.utils import Choices, Message, ModelResponse, Usage
-
-from ..common_utils import CohereError
-
-
-def construct_cohere_tool(tools=None):
-    if tools is None:
-        tools = []
-    return {"tools": tools}
-
-
-def validate_environment(api_key, headers: dict):
-    headers.update(
-        {
-            "Request-Source": "unspecified:litellm",
-            "accept": "application/json",
-            "content-type": "application/json",
-        }
-    )
-    if api_key:
-        headers["Authorization"] = f"Bearer {api_key}"
-    return headers
-
-
-def completion(
-    model: str,
-    messages: list,
-    api_base: str,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    headers: dict,
-    optional_params: dict,
-    litellm_params=None,
-    logger_fn=None,
-):
-    headers = validate_environment(api_key, headers=headers)
-    completion_url = api_base
-    model = model
-    prompt = " ".join(message["content"] for message in messages)
-
-    ## Load Config
-    config = litellm.CohereConfig.get_config()
-    for k, v in config.items():
-        if (
-            k not in optional_params
-        ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
-            optional_params[k] = v
-
-    ## Handle Tool Calling
-    if "tools" in optional_params:
-        _is_function_call = True
-        tool_calling_system_prompt = construct_cohere_tool(
-            tools=optional_params["tools"]
-        )
-        optional_params["tools"] = tool_calling_system_prompt
-
-    data = {
-        "model": model,
-        "prompt": prompt,
-        **optional_params,
-    }
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=prompt,
-        api_key=api_key,
-        additional_args={
-            "complete_input_dict": data,
-            "headers": headers,
-            "api_base": completion_url,
-        },
-    )
-    ## COMPLETION CALL
-    response = requests.post(
-        completion_url,
-        headers=headers,
-        data=json.dumps(data),
-        stream=optional_params["stream"] if "stream" in optional_params else False,
-    )
-    ## error handling for cohere calls
-    if response.status_code != 200:
-        raise CohereError(message=response.text, status_code=response.status_code)
-
-    if "stream" in optional_params and optional_params["stream"] is True:
-        return response.iter_lines()
-    else:
-        ## LOGGING
-        logging_obj.post_call(
-            input=prompt,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        completion_response = response.json()
-        if "error" in completion_response:
-            raise CohereError(
-                message=completion_response["error"],
-                status_code=response.status_code,
-            )
-        else:
-            try:
-                choices_list = []
-                for idx, item in enumerate(completion_response["generations"]):
-                    if len(item["text"]) > 0:
-                        message_obj = Message(content=item["text"])
-                    else:
-                        message_obj = Message(content=None)
-                    choice_obj = Choices(
-                        finish_reason=item["finish_reason"],
-                        index=idx + 1,
-                        message=message_obj,
-                    )
-                    choices_list.append(choice_obj)
-                model_response.choices = choices_list  # type: ignore
-            except Exception:
-                raise CohereError(
-                    message=response.text, status_code=response.status_code
-                )
-
-        ## CALCULATING USAGE
-        prompt_tokens = len(encoding.encode(prompt))
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-        )
-
-        model_response.created = int(time.time())
-        model_response.model = model
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        )
-        setattr(model_response, "usage", usage)
-        return model_response
@@ -1,13 +1,26 @@
+import json
+import time
 import types
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union

 import httpx

+import litellm
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import ModelResponse
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    Choices,
+    GenericStreamingChunk,
+    Message,
+    ModelResponse,
+    Usage,
+)

 from ..common_utils import CohereError
+from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
 from ..common_utils import validate_environment as cohere_validate_environment

 if TYPE_CHECKING:

@@ -98,7 +111,13 @@ class CohereTextConfig(BaseConfig):
         optional_params: dict,
         api_key: Optional[str] = None,
     ) -> dict:
-        return cohere_validate_environment(api_key=api_key, headers=headers)
+        return cohere_validate_environment(
+            headers=headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )

     def _transform_messages(
         self,

@@ -161,7 +180,33 @@ class CohereTextConfig(BaseConfig):
         litellm_params: dict,
         headers: dict,
     ) -> dict:
-        raise NotImplementedError
+        prompt = " ".join(
+            convert_content_list_to_str(message=message) for message in messages
+        )
+
+        ## Load Config
+        config = litellm.CohereConfig.get_config()
+        for k, v in config.items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        ## Handle Tool Calling
+        if "tools" in optional_params:
+            _is_function_call = True
+            tool_calling_system_prompt = self._construct_cohere_tool_for_completion_api(
+                tools=optional_params["tools"]
+            )
+            optional_params["tools"] = tool_calling_system_prompt
+
+        data = {
+            "model": model,
+            "prompt": prompt,
+            **optional_params,
+        }
+
+        return data

     def transform_response(
         self,

@@ -176,4 +221,56 @@ class CohereTextConfig(BaseConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        raise NotImplementedError
+        prompt = " ".join(
+            convert_content_list_to_str(message=message) for message in messages
+        )
+        completion_response = raw_response.json()
+        choices_list = []
+        for idx, item in enumerate(completion_response["generations"]):
+            if len(item["text"]) > 0:
+                message_obj = Message(content=item["text"])
+            else:
+                message_obj = Message(content=None)
+            choice_obj = Choices(
+                finish_reason=item["finish_reason"],
+                index=idx + 1,
+                message=message_obj,
+            )
+            choices_list.append(choice_obj)
+        model_response.choices = choices_list  # type: ignore
+
+        ## CALCULATING USAGE
+        prompt_tokens = len(encoding.encode(prompt))
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def _construct_cohere_tool_for_completion_api(
+        self,
+        tools: Optional[List] = None,
+    ) -> dict:
+        if tools is None:
+            tools = []
+        return {"tools": tools}
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return CohereModelResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
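
transform_request builds the same /generate payload the deleted handler produced: messages flattened to a single prompt plus the optional params. A sketch of the expected mapping, mirroring the assertions in the new test file added below:

    from litellm.llms.cohere.completion.transformation import CohereTextConfig

    config = CohereTextConfig()
    request = config.transform_request(
        model="command-nightly",
        messages=[
            {"role": "system", "content": "You're a helpful bot"},
            {"role": "user", "content": "Hello"},
        ],
        optional_params={"max_tokens": 10, "temperature": 0.7},
        litellm_params={},
        headers={},
    )
    # Expected, per the test assertions:
    #   request["model"] == "command-nightly"
    #   request["prompt"] == "You're a helpful bot Hello"
    #   request["max_tokens"] == 10
    #   request["temperature"] == 0.7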
@@ -373,8 +373,12 @@ class BaseLLMHTTPHandler:
         error_headers = getattr(error_response, "headers", None)
         if error_response and hasattr(error_response, "text"):
             error_text = getattr(error_response, "text", error_text)
-        raise provider_config.error_class(  # type: ignore
-            message=error_text,
+        if error_headers:
+            error_headers = dict(error_headers)
+        else:
+            error_headers = {}
+        raise provider_config.get_error_class(
+            error_message=error_text,
             status_code=status_code,
             headers=error_headers,
         )
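
Per the bad-key test added further down, one visible effect of this error path is that an invalid Cohere key still surfaces as litellm.AuthenticationError. A minimal sketch (the key value is a deliberate fake):

    import litellm

    try:
        litellm.completion(
            model="command-nightly",
            messages=[{"role": "user", "content": "Hey"}],
            api_key="bad-key",  # intentionally invalid
            max_tokens=10,
        )
    except litellm.AuthenticationError as e:
        print("auth error from Cohere:", e)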
@@ -109,7 +109,6 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
-from .llms.cohere.completion import completion as cohere_completion  # type: ignore
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router

@@ -446,6 +445,7 @@ async def acompletion(
             or custom_llm_provider == "groq"
             or custom_llm_provider == "nvidia_nim"
             or custom_llm_provider == "cohere_chat"
+            or custom_llm_provider == "cohere"
             or custom_llm_provider == "cerebras"
             or custom_llm_provider == "sambanova"
             or custom_llm_provider == "ai21_chat"

@@ -1895,31 +1895,22 @@ def completion(  # type: ignore # noqa: PLR0915
         if extra_headers is not None:
             headers.update(extra_headers)

-        model_response = cohere_completion.completion(
+        response = base_llm_http_handler.completion(
             model=model,
+            stream=stream,
             messages=messages,
+            acompletion=acompletion,
             api_base=api_base,
             model_response=model_response,
-            print_verbose=print_verbose,
             optional_params=optional_params,
             litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            encoding=encoding,
+            custom_llm_provider="cohere",
+            timeout=timeout,
             headers=headers,
+            encoding=encoding,
             api_key=cohere_key,
             logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
         )
-
-        if "stream" in optional_params and optional_params["stream"] is True:
-            # don't try to access stream object,
-            response = CustomStreamWrapper(
-                model_response,
-                model,
-                custom_llm_provider="cohere",
-                logging_obj=logging,
-            )
-            return response
-        response = model_response
     elif custom_llm_provider == "cohere_chat":
         cohere_key = (
             api_key
tests/llm_translation/test_cohere_generate_api.py (new file, 184 lines)

@@ -0,0 +1,184 @@
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import json
+
+import pytest
+
+import litellm
+from litellm import completion
+from litellm.llms.cohere.completion.transformation import CohereTextConfig
+
+
+@pytest.mark.asyncio
+async def test_cohere_generate_api_completion():
+    try:
+        litellm.set_verbose = False
+        messages = [
+            {"role": "system", "content": "You're a good bot"},
+            {
+                "role": "user",
+                "content": "Hey",
+            },
+        ]
+        response = completion(
+            model="cohere/command-nightly",
+            messages=messages,
+            max_tokens=10,
+        )
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_cohere_generate_api_stream():
+    try:
+        litellm.set_verbose = True
+        messages = [
+            {"role": "system", "content": "You're a good bot"},
+            {
+                "role": "user",
+                "content": "Hey",
+            },
+        ]
+        response = await litellm.acompletion(
+            model="cohere/command-nightly",
+            messages=messages,
+            max_tokens=10,
+            stream=True,
+        )
+        print("async cohere stream response", response)
+        async for chunk in response:
+            print(chunk)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_completion_cohere_stream_bad_key():
+    try:
+        api_key = "bad-key"
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        completion(
+            model="command-nightly",
+            messages=messages,
+            stream=True,
+            max_tokens=50,
+            api_key=api_key,
+        )
+
+    except litellm.AuthenticationError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_cohere_transform_request():
+    try:
+        config = CohereTextConfig()
+        messages = [
+            {"role": "system", "content": "You're a helpful bot"},
+            {"role": "user", "content": "Hello"},
+        ]
+        optional_params = {"max_tokens": 10, "temperature": 0.7}
+        headers = {}
+
+        transformed_request = config.transform_request(
+            model="command-nightly",
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params={},
+            headers=headers,
+        )
+
+        print("transformed_request", json.dumps(transformed_request, indent=4))
+
+        assert transformed_request["model"] == "command-nightly"
+        assert transformed_request["prompt"] == "You're a helpful bot Hello"
+        assert transformed_request["max_tokens"] == 10
+        assert transformed_request["temperature"] == 0.7
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_cohere_transform_request_with_tools():
+    try:
+        config = CohereTextConfig()
+        messages = [{"role": "user", "content": "What's the weather?"}]
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get weather information",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {"location": {"type": "string"}},
+                    },
+                },
+            }
+        ]
+        optional_params = {"tools": tools}
+
+        transformed_request = config.transform_request(
+            model="command-nightly",
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params={},
+            headers={},
+        )
+
+        print("transformed_request", json.dumps(transformed_request, indent=4))
+        assert "tools" in transformed_request
+        assert transformed_request["tools"] == {"tools": tools}
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_cohere_map_openai_params():
+    try:
+        config = CohereTextConfig()
+        openai_params = {
+            "temperature": 0.7,
+            "max_tokens": 100,
+            "n": 2,
+            "top_p": 0.9,
+            "frequency_penalty": 0.5,
+            "presence_penalty": 0.5,
+            "stop": ["END"],
+            "stream": True,
+        }
+
+        mapped_params = config.map_openai_params(
+            non_default_params=openai_params,
+            optional_params={},
+            model="command-nightly",
+            drop_params=False,
+        )
+
+        assert mapped_params["temperature"] == 0.7
+        assert mapped_params["max_tokens"] == 100
+        assert mapped_params["num_generations"] == 2
+        assert mapped_params["p"] == 0.9
+        assert mapped_params["frequency_penalty"] == 0.5
+        assert mapped_params["presence_penalty"] == 0.5
+        assert mapped_params["stop_sequences"] == ["END"]
+        assert mapped_params["stream"] == True
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
@@ -436,47 +436,6 @@ def test_completion_azure_stream_content_filter_no_delta():
         pytest.fail(f"An exception occurred - {str(e)}")


-def test_completion_cohere_stream_bad_key():
-    try:
-        litellm.cache = None
-        api_key = "bad-key"
-        messages = [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {
-                "role": "user",
-                "content": "how does a court case get to the Supreme Court?",
-            },
-        ]
-        response = completion(
-            model="command-nightly",
-            messages=messages,
-            stream=True,
-            max_tokens=50,
-            api_key=api_key,
-        )
-        complete_response = ""
-        # Add any assertions here to check the response
-        has_finish_reason = False
-        for idx, chunk in enumerate(response):
-            chunk, finished = streaming_format_tests(idx, chunk)
-            has_finish_reason = finished
-            if finished:
-                break
-            complete_response += chunk
-        if has_finish_reason is False:
-            raise Exception("Finish reason not in final chunk")
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
-        print(f"completion_response: {complete_response}")
-    except AuthenticationError as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
-# test_completion_cohere_stream_bad_key()
-
-
 @pytest.mark.flaky(retries=5, delay=1)
 def test_completion_azure_stream():
     try: