feat(cohere/chat.py): return citations in model response

Closes https://github.com/BerriAI/litellm/issues/6814
This commit is contained in:
Krrish Dholakia 2024-11-30 13:59:57 -08:00
parent bd59f18809
commit 2fbc71a62c
6 changed files with 310 additions and 254 deletions

View file

@ -437,29 +437,6 @@ class CustomStreamWrapper:
except Exception: except Exception:
raise ValueError(f"Unable to parse response. Original response: {chunk}") raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_cohere_chat_chunk(self, chunk):
    """Translate one raw cohere_chat stream chunk into a text/finish dict.

    ``chunk`` is UTF-8 encoded bytes holding a JSON document. Returns a dict
    with the keys ``text``, ``is_finished`` and ``finish_reason``, or ``None``
    when the chunk carries neither text nor a positive ``is_finished`` flag.

    Raises:
        ValueError: if reading the expected fields from the parsed chunk fails.
    """
    chunk = chunk.decode("utf-8")
    payload = json.loads(chunk)
    print_verbose(f"chunk: {chunk}")
    try:
        # Text chunks take priority over the end-of-stream marker.
        if "text" in payload:
            return {
                "text": payload["text"],
                "is_finished": False,
                "finish_reason": "",
            }
        if "is_finished" in payload and payload["is_finished"] is True:
            return {
                "text": "",
                "is_finished": payload["is_finished"],
                "finish_reason": payload["finish_reason"],
            }
        # Nothing useful in this chunk; the caller treats None as "skip".
        return None
    except Exception:
        raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_azure_chunk(self, chunk): def handle_azure_chunk(self, chunk):
is_finished = False is_finished = False
finish_reason = "" finish_reason = ""
@ -949,7 +926,12 @@ class CustomStreamWrapper:
"function_call" in completion_obj "function_call" in completion_obj
and completion_obj["function_call"] is not None and completion_obj["function_call"] is not None
) )
or (
"provider_specific_fields" in response_obj
and response_obj["provider_specific_fields"] is not None
)
): # cannot set content of an OpenAI Object to be an empty string ): # cannot set content of an OpenAI Object to be an empty string
self.safety_checker() self.safety_checker()
hold, model_response_str = self.check_special_tokens( hold, model_response_str = self.check_special_tokens(
chunk=completion_obj["content"], chunk=completion_obj["content"],
@ -1058,6 +1040,7 @@ class CustomStreamWrapper:
and model_response.choices[0].delta.audio is not None and model_response.choices[0].delta.audio is not None
): ):
return model_response return model_response
else: else:
if hasattr(model_response, "usage"): if hasattr(model_response, "usage"):
self.chunks.append(model_response) self.chunks.append(model_response)
@ -1066,6 +1049,7 @@ class CustomStreamWrapper:
def chunk_creator(self, chunk): # type: ignore # noqa: PLR0915 def chunk_creator(self, chunk): # type: ignore # noqa: PLR0915
model_response = self.model_response_creator() model_response = self.model_response_creator()
response_obj: dict = {} response_obj: dict = {}
try: try:
# return this for all models # return this for all models
completion_obj = {"content": ""} completion_obj = {"content": ""}
@ -1256,14 +1240,6 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"] completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]: if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"] self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "cohere_chat":
response_obj = self.handle_cohere_chat_chunk(chunk)
if response_obj is None:
return
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "petals": elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0: if len(self.completion_stream) == 0:
if self.received_finish_reason is not None: if self.received_finish_reason is not None:

View file

@ -4,13 +4,20 @@ import time
import traceback import traceback
import types import types
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Any, Callable, List, Optional, Tuple, Union
import httpx # type: ignore import httpx # type: ignore
import requests # type: ignore import requests # type: ignore
import litellm import litellm
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.cohere import ToolResultObject from litellm.types.llms.cohere import ToolResultObject
from litellm.types.utils import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
GenericStreamingChunk,
)
from litellm.utils import Choices, Message, ModelResponse, Usage from litellm.utils import Choices, Message, ModelResponse, Usage
from ..prompt_templates.factory import cohere_message_pt, cohere_messages_pt_v2 from ..prompt_templates.factory import cohere_message_pt, cohere_messages_pt_v2
@ -198,6 +205,106 @@ def construct_cohere_tool(tools=None):
return cohere_tools return cohere_tools
async def make_call(
    client: Optional[AsyncHTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
    timeout: Optional[Union[float, httpx.Timeout]],
    json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
    """Start an async streaming POST and wrap the line stream for parsing.

    Args:
        client: Async HTTP handler to use; falls back to
            ``litellm.module_level_aclient`` when ``None``.
        api_base: URL the request is posted to.
        headers: Request headers.
        data: Serialized request body.
        model: Unused here; kept for signature parity with callers.
        messages: Original input messages, forwarded to the logging hook.
        logging_obj: Object whose ``post_call`` hook is invoked once the
            stream has been opened.
        timeout: Request timeout passed through to the client.
        json_mode: Forwarded to ``ModelResponseIterator``.

    Returns:
        ``(completion_stream, response_headers)`` where ``completion_stream``
        is a ``ModelResponseIterator`` over ``response.aiter_lines()``.

    Raises:
        CohereError: on an HTTP status error, on a non-200 response, or (with
            status 500) on any exception that is not one of
            ``litellm.LITELLM_EXCEPTION_TYPES``; those are re-raised as-is.
    """
    if client is None:
        client = litellm.module_level_aclient

    try:
        response = await client.post(
            api_base, headers=headers, data=data, stream=True, timeout=timeout
        )
    except httpx.HTTPStatusError as e:
        raise CohereError(
            status_code=e.response.status_code,
            message=await e.response.aread(),
        )
    except Exception as e:
        # Known litellm exception types pass through untouched.
        for exception in litellm.LITELLM_EXCEPTION_TYPES:
            if isinstance(e, exception):
                raise e
        raise CohereError(status_code=500, message=str(e))

    # Mirror the sync path: never hand an error body to the stream parser.
    if response.status_code != 200:
        raise CohereError(
            status_code=response.status_code,
            message=await response.aread(),
        )

    completion_stream = ModelResponseIterator(
        streaming_response=response.aiter_lines(),
        sync_stream=False,
        json_mode=json_mode,
    )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream, response.headers
def make_sync_call(
    client: Optional[HTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
    timeout: Optional[Union[float, httpx.Timeout]],
) -> Tuple[Any, httpx.Headers]:
    """Open a blocking streaming POST and return a chunk iterator plus headers.

    When ``client`` is ``None`` the shared ``litellm.module_level_client`` is
    used. HTTP status errors and non-200 responses are raised as
    ``CohereError``; exceptions matching ``litellm.LITELLM_EXCEPTION_TYPES``
    are re-raised unchanged and anything else becomes a ``CohereError`` with
    status 500. After the stream is opened, ``logging_obj.post_call`` is
    invoked with a fixed placeholder string as the original response.

    Returns:
        ``(completion_stream, response_headers)`` where ``completion_stream``
        is a ``ModelResponseIterator`` over ``response.iter_lines()``.
    """
    # re-use a module level client when the caller did not supply one
    http_client = litellm.module_level_client if client is None else client

    try:
        response = http_client.post(
            api_base, headers=headers, data=data, stream=True, timeout=timeout
        )
    except httpx.HTTPStatusError as status_err:
        failed_response = status_err.response
        raise CohereError(
            status_code=failed_response.status_code,
            message=failed_response.read(),
        )
    except Exception as err:
        if any(
            isinstance(err, known_type)
            for known_type in litellm.LITELLM_EXCEPTION_TYPES
        ):
            raise err
        raise CohereError(status_code=500, message=str(err))

    status = response.status_code
    if status != 200:
        raise CohereError(status_code=status, message=response.read())

    line_stream = response.iter_lines()
    completion_stream = ModelResponseIterator(
        streaming_response=line_stream, sync_stream=True
    )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )

    return completion_stream, response.headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -211,6 +318,8 @@ def completion(
logging_obj, logging_obj,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
client=None,
timeout=None,
): ):
headers = validate_environment(api_key, headers=headers) headers = validate_environment(api_key, headers=headers)
completion_url = api_base completion_url = api_base
@ -269,7 +378,23 @@ def completion(
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] is True: if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines() completion_stream, headers = make_sync_call(
client=client,
api_base=api_base,
headers=headers, # type: ignore
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cohere_chat",
logging_obj=logging_obj,
_response_headers=headers,
)
else: else:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -286,6 +411,10 @@ def completion(
except Exception: except Exception:
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(message=response.text, status_code=response.status_code)
## ADD CITATIONS
if "citations" in completion_response:
setattr(model_response, "citations", completion_response["citations"])
## Tool calling response ## Tool calling response
cohere_tools_response = completion_response.get("tool_calls", None) cohere_tools_response = completion_response.get("tool_calls", None)
if cohere_tools_response is not None and cohere_tools_response != []: if cohere_tools_response is not None and cohere_tools_response != []:
@ -325,3 +454,103 @@ def completion(
) )
setattr(model_response, "usage", usage) setattr(model_response, "usage", usage)
return model_response return model_response
class ModelResponseIterator:
    """Iterate a streamed line response and yield parsed streaming chunks.

    Wraps either a sync iterator (used via ``__iter__``/``__next__``) or an
    async iterable (used via ``__aiter__``/``__anext__``) of raw lines. Each
    non-blank line is decoded as JSON and converted by ``chunk_parser``.
    """

    def __init__(
        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
    ):
        # NOTE: ``sync_stream`` is accepted for interface compatibility but is
        # not consulted; the protocol used by the caller selects the mode.
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.content_blocks: List = []
        self.tool_index = -1
        self.json_mode = json_mode

    def chunk_parser(self, chunk: dict) -> "GenericStreamingChunk":
        """Build a ``GenericStreamingChunk`` from one decoded stream event.

        Reads ``text``, ``is_finished``/``finish_reason``, ``citations`` and
        ``index`` from ``chunk``. Citations, when present, are surfaced under
        ``provider_specific_fields["citations"]``.
        """
        text = ""
        tool_use: Optional[ChatCompletionToolCallChunk] = None
        is_finished = False
        finish_reason = ""
        usage: Optional[ChatCompletionUsageBlock] = None
        provider_specific_fields = None

        index = int(chunk.get("index", 0))

        if "text" in chunk:
            text = chunk["text"]
        elif "is_finished" in chunk and chunk["is_finished"] is True:
            is_finished = chunk["is_finished"]
            finish_reason = chunk["finish_reason"]

        if "citations" in chunk:
            provider_specific_fields = {"citations": chunk["citations"]}

        return GenericStreamingChunk(
            text=text,
            tool_use=tool_use,
            is_finished=is_finished,
            finish_reason=finish_reason,
            usage=usage,
            index=index,
            provider_specific_fields=provider_specific_fields,
        )

    @staticmethod
    def _decode_line(chunk: Union[str, bytes]) -> Optional[dict]:
        """Decode one raw stream line into parsed JSON.

        Returns ``None`` for blank lines. A leading ``data:`` prefix is
        removed, but only when it starts the line: searching for ``data:``
        anywhere would cut into JSON whose text merely contains that
        substring and make it unparseable.

        Raises:
            ValueError: on invalid UTF-8 or invalid JSON (both
                ``UnicodeDecodeError`` and ``json.JSONDecodeError`` are
                ``ValueError`` subclasses).
        """
        line = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
        line = line.strip()
        if line.startswith("data:"):
            line = line[len("data:"):].strip()
        if not line:
            return None
        return json.loads(line)

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        # Loop so blank separator lines are skipped rather than failing.
        while True:
            try:
                chunk = next(self.response_iterator)
            except ValueError as e:
                raise RuntimeError(f"Error receiving chunk from stream: {e}")
            # StopIteration propagates unchanged to end the iteration.

            try:
                data_json = self._decode_line(chunk)
                if data_json is None:
                    continue
                return self.chunk_parser(chunk=data_json)
            except ValueError as e:
                raise RuntimeError(
                    f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
                )

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        # Loop so blank separator lines are skipped rather than failing.
        while True:
            try:
                chunk = await self.async_response_iterator.__anext__()
            except ValueError as e:
                raise RuntimeError(f"Error receiving chunk from stream: {e}")
            # StopAsyncIteration propagates unchanged to end the iteration.

            try:
                data_json = self._decode_line(chunk)
                if data_json is None:
                    continue
                return self.chunk_parser(chunk=data_json)
            except ValueError as e:
                raise RuntimeError(
                    f"Error parsing chunk: {e},\nReceived chunk: {chunk}"
                )

View file

@ -1970,15 +1970,16 @@ def completion( # type: ignore # noqa: PLR0915
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
) )
if "stream" in optional_params and optional_params["stream"] is True: # if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object, # # don't try to access stream object,
response = CustomStreamWrapper( # response = CustomStreamWrapper(
model_response, # model_response,
model, # model,
custom_llm_provider="cohere_chat", # custom_llm_provider="cohere_chat",
logging_obj=logging, # logging_obj=logging,
) # _response_headers=headers,
return response # )
# return response
response = model_response response = model_response
elif custom_llm_provider == "maritalk": elif custom_llm_provider == "maritalk":
maritalk_key = ( maritalk_key = (

View file

@ -0,0 +1,59 @@
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import io
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import json
import pytest
import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
litellm.num_retries = 3
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.asyncio
async def test_chat_completion_cohere_citations(stream):
    """Check that a documents-grounded cohere_chat call exposes citations.

    Streaming: at least one chunk must contain ``"citations"``.
    Non-streaming: ``response.citations`` must not be ``None``.
    Any exception, including a failed assertion, is reported via
    ``pytest.fail``.
    """
    try:
        litellm.set_verbose = True

        grounding_documents = [
            {"title": "Tall penguins", "text": "Emperor penguins are the tallest."},
            {
                "title": "Penguin habitats",
                "text": "Emperor penguins only live in Antarctica.",
            },
        ]
        messages = [
            {
                "role": "user",
                "content": "Which penguins are the tallest?",
            },
        ]

        response = await litellm.acompletion(
            model="cohere_chat/command-r",
            messages=messages,
            documents=grounding_documents,
            stream=stream,
        )

        if not stream:
            assert response.citations is not None
            return

        # Streaming: stop at the first chunk that carries citations.
        saw_citations = False
        async for chunk in response:
            print("received chunk", chunk)
            if "citations" in chunk:
                saw_citations = True
                break
        assert saw_citations
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

View file

@ -1,210 +0,0 @@
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import io
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import json
import pytest
import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
litellm.num_retries = 3
# FYI - cohere_chat looks quite unstable, even when testing locally
def test_chat_completion_cohere():
    """Smoke test: a short cohere_chat completion must not raise."""
    try:
        litellm.set_verbose = True
        greeting = {
            "role": "user",
            "content": "Hey",
        }
        response = completion(
            model="cohere_chat/command-r",
            messages=[greeting],
            max_tokens=10,
        )
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
def test_chat_completion_cohere_tool_calling():
    """Smoke test: a cohere_chat completion with one tool must not raise."""
    try:
        litellm.set_verbose = True

        # Schema of the single function tool offered to the model.
        weather_parameters = {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "required": ["location"],
        }
        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": weather_parameters,
            },
        }
        messages = [
            {
                "role": "user",
                "content": "What is the weather like in Boston?",
            },
        ]

        response = completion(
            model="cohere_chat/command-r",
            messages=messages,
            tools=[weather_tool],
        )
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# def get_current_weather(location, unit="fahrenheit"):
# """Get the current weather in a given location"""
# if "tokyo" in location.lower():
# return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
# elif "san francisco" in location.lower():
# return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit})
# elif "paris" in location.lower():
# return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
# else:
# return json.dumps({"location": location, "temperature": "unknown"})
# def test_chat_completion_cohere_tool_with_result_calling():
# # end to end cohere command-r with tool calling
# # Step 1 - Send available tools
# # Step 2 - Execute results
# # Step 3 - Send results to command-r
# try:
# litellm.set_verbose = True
# import json
# # Step 1 - Send available tools
# tools = [
# {
# "type": "function",
# "function": {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA",
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"],
# },
# },
# "required": ["location"],
# },
# },
# }
# ]
# messages = [
# {
# "role": "user",
# "content": "What is the weather like in Boston?",
# },
# ]
# response = completion(
# model="cohere_chat/command-r",
# messages=messages,
# tools=tools,
# )
# print("Response with tools to call", response)
# print(response)
# # step 2 - Execute results
# tool_calls = response.tool_calls
# available_functions = {
# "get_current_weather": get_current_weather,
# } # only one function in this example, but you can have multiple
# for tool_call in tool_calls:
# function_name = tool_call.function.name
# function_to_call = available_functions[function_name]
# function_args = json.loads(tool_call.function.arguments)
# function_response = function_to_call(
# location=function_args.get("location"),
# unit=function_args.get("unit"),
# )
# messages.append(
# {
# "tool_call_id": tool_call.id,
# "role": "tool",
# "name": function_name,
# "content": function_response,
# }
# ) # extend conversation with function response
# print("messages with tool call results", messages)
# messages = [
# {
# "role": "user",
# "content": "What is the weather like in Boston?",
# },
# {
# "tool_call_id": "tool_1",
# "role": "tool",
# "name": "get_current_weather",
# "content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"},
# },
# ]
# respone = completion(
# model="cohere_chat/command-r",
# messages=messages,
# tools=[
# {
# "type": "function",
# "function": {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA",
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"],
# },
# },
# "required": ["location"],
# },
# },
# }
# ],
# )
# print(respone)
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -46,11 +46,12 @@ def get_current_weather(location, unit="fahrenheit"):
"model", "model",
[ [
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-1106",
# "mistral/mistral-large-latest", "mistral/mistral-large-latest",
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
"gemini/gemini-1.5-pro", "gemini/gemini-1.5-pro",
"anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0",
# "groq/llama3-8b-8192", "groq/llama3-8b-8192",
"cohere_chat/command-r",
], ],
) )
@pytest.mark.flaky(retries=3, delay=1) @pytest.mark.flaky(retries=3, delay=1)