From 2fbc71a62c3236f99bc59e4ccbb17d16b63d1e26 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 30 Nov 2024 13:59:57 -0800 Subject: [PATCH] feat(cohere/chat.py): return citations in model response Closes https://github.com/BerriAI/litellm/issues/6814 --- .../litellm_core_utils/streaming_handler.py | 38 +-- litellm/llms/cohere/chat.py | 233 +++++++++++++++++- litellm/main.py | 19 +- tests/llm_translation/test_cohere.py | 59 +++++ tests/local_testing/test_cohere_completion.py | 210 ---------------- tests/local_testing/test_function_calling.py | 5 +- 6 files changed, 310 insertions(+), 254 deletions(-) create mode 100644 tests/llm_translation/test_cohere.py delete mode 100644 tests/local_testing/test_cohere_completion.py diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 483121c38..c1cc20a7d 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -437,29 +437,6 @@ class CustomStreamWrapper: except Exception: raise ValueError(f"Unable to parse response. Original response: {chunk}") - def handle_cohere_chat_chunk(self, chunk): - chunk = chunk.decode("utf-8") - data_json = json.loads(chunk) - print_verbose(f"chunk: {chunk}") - try: - text = "" - is_finished = False - finish_reason = "" - if "text" in data_json: - text = data_json["text"] - elif "is_finished" in data_json and data_json["is_finished"] is True: - is_finished = data_json["is_finished"] - finish_reason = data_json["finish_reason"] - else: - return - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } - except Exception: - raise ValueError(f"Unable to parse response. Original response: {chunk}") - def handle_azure_chunk(self, chunk): is_finished = False finish_reason = "" @@ -949,7 +926,12 @@ class CustomStreamWrapper: "function_call" in completion_obj and completion_obj["function_call"] is not None ) + or ( + "provider_specific_fields" in response_obj + and response_obj["provider_specific_fields"] is not None + ) ): # cannot set content of an OpenAI Object to be an empty string + self.safety_checker() hold, model_response_str = self.check_special_tokens( chunk=completion_obj["content"], @@ -1058,6 +1040,7 @@ class CustomStreamWrapper: and model_response.choices[0].delta.audio is not None ): return model_response + else: if hasattr(model_response, "usage"): self.chunks.append(model_response) @@ -1066,6 +1049,7 @@ class CustomStreamWrapper: def chunk_creator(self, chunk): # type: ignore # noqa: PLR0915 model_response = self.model_response_creator() response_obj: dict = {} + try: # return this for all models completion_obj = {"content": ""} @@ -1256,14 +1240,6 @@ class CustomStreamWrapper: completion_obj["content"] = response_obj["text"] if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] - elif self.custom_llm_provider == "cohere_chat": - response_obj = self.handle_cohere_chat_chunk(chunk) - if response_obj is None: - return - completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - self.received_finish_reason = response_obj["finish_reason"] - elif self.custom_llm_provider == "petals": if len(self.completion_stream) == 0: if self.received_finish_reason is not None: diff --git a/litellm/llms/cohere/chat.py b/litellm/llms/cohere/chat.py index e0a92b6c8..d65c4f895 100644 --- a/litellm/llms/cohere/chat.py +++ b/litellm/llms/cohere/chat.py @@ -4,13 +4,20 @@ import time import traceback import types from enum import Enum -from typing import Callable, Optional +from typing import Any, Callable, List, Optional, Tuple, Union import httpx # type: ignore import requests # type: ignore import litellm +from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.types.llms.cohere import ToolResultObject +from litellm.types.utils import ( + ChatCompletionToolCallChunk, + ChatCompletionUsageBlock, + GenericStreamingChunk, +) from litellm.utils import Choices, Message, ModelResponse, Usage from ..prompt_templates.factory import cohere_message_pt, cohere_messages_pt_v2 @@ -198,6 +205,106 @@ def construct_cohere_tool(tools=None): return cohere_tools +async def make_call( + client: Optional[AsyncHTTPHandler], + api_base: str, + headers: dict, + data: str, + model: str, + messages: list, + logging_obj, + timeout: Optional[Union[float, httpx.Timeout]], + json_mode: bool, +) -> Tuple[Any, httpx.Headers]: + if client is None: + client = litellm.module_level_aclient + + try: + response = await client.post( + api_base, headers=headers, data=data, stream=True, timeout=timeout + ) + except httpx.HTTPStatusError as e: + error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise CohereError( + status_code=e.response.status_code, + message=await e.response.aread(), + ) + except Exception as e: + for exception in litellm.LITELLM_EXCEPTION_TYPES: + if isinstance(e, exception): + raise e + raise CohereError(status_code=500, message=str(e)) + + completion_stream = ModelResponseIterator( + streaming_response=response.aiter_lines(), + sync_stream=False, + json_mode=json_mode, + ) + + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response=completion_stream, # Pass the completion stream for logging + additional_args={"complete_input_dict": data}, + ) + + return completion_stream, response.headers + + +def make_sync_call( + client: Optional[HTTPHandler], + api_base: str, + headers: dict, + data: str, + model: str, + messages: list, + logging_obj, + timeout: Optional[Union[float, httpx.Timeout]], +) -> Tuple[Any, httpx.Headers]: + if client is None: + client = litellm.module_level_client # re-use a module level client + + try: + response = client.post( + api_base, headers=headers, data=data, stream=True, timeout=timeout + ) + except httpx.HTTPStatusError as e: + raise CohereError( + status_code=e.response.status_code, + message=e.response.read(), + ) + except Exception as e: + for exception in litellm.LITELLM_EXCEPTION_TYPES: + if isinstance(e, exception): + raise e + raise CohereError(status_code=500, message=str(e)) + + if response.status_code != 200: + + raise CohereError( + status_code=response.status_code, + message=response.read(), + ) + + completion_stream = ModelResponseIterator( + streaming_response=response.iter_lines(), sync_stream=True + ) + + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return completion_stream, response.headers + + def completion( model: str, messages: list, @@ -211,6 +318,8 @@ def completion( logging_obj, litellm_params=None, logger_fn=None, + client=None, + timeout=None, ): headers = validate_environment(api_key, headers=headers) completion_url = api_base @@ -269,7 +378,23 @@ def completion( raise CohereError(message=response.text, status_code=response.status_code) if "stream" in optional_params and optional_params["stream"] is True: - return response.iter_lines() + completion_stream, headers = make_sync_call( + client=client, + api_base=api_base, + headers=headers, # type: ignore + data=json.dumps(data), + model=model, + messages=messages, + logging_obj=logging_obj, + timeout=timeout, + ) + return CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="cohere_chat", + logging_obj=logging_obj, + _response_headers=headers, + ) else: ## LOGGING logging_obj.post_call( @@ -286,6 +411,10 @@ def completion( except Exception: raise CohereError(message=response.text, status_code=response.status_code) + ## ADD CITATIONS + if "citations" in completion_response: + setattr(model_response, "citations", completion_response["citations"]) + ## Tool calling response cohere_tools_response = completion_response.get("tool_calls", None) if cohere_tools_response is not None and cohere_tools_response != []: @@ -325,3 +454,103 @@ def completion( ) setattr(model_response, "usage", usage) return model_response + + +class ModelResponseIterator: + def __init__( + self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False + ): + self.streaming_response = streaming_response + self.response_iterator = self.streaming_response + self.content_blocks: List = [] + self.tool_index = -1 + self.json_mode = json_mode + + def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: + try: + text = "" + tool_use: Optional[ChatCompletionToolCallChunk] = None + is_finished = False + finish_reason = "" + usage: Optional[ChatCompletionUsageBlock] = None + provider_specific_fields = None + + index = int(chunk.get("index", 0)) + + if "text" in chunk: + text = chunk["text"] + elif "is_finished" in chunk and chunk["is_finished"] is True: + is_finished = chunk["is_finished"] + finish_reason = chunk["finish_reason"] + + if "citations" in chunk: + provider_specific_fields = {"citations": chunk["citations"]} + + returned_chunk = GenericStreamingChunk( + text=text, + tool_use=tool_use, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + index=index, + provider_specific_fields=provider_specific_fields, + ) + + return returned_chunk + + except json.JSONDecodeError: + raise ValueError(f"Failed to decode JSON from chunk: {chunk}") + + # Sync iterator + def __iter__(self): + return self + + def __next__(self): + try: + chunk = self.response_iterator.__next__() + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + str_line = chunk + if isinstance(chunk, bytes): # Handle binary data + str_line = chunk.decode("utf-8") # Convert bytes to string + index = str_line.find("data:") + if index != -1: + str_line = str_line[index:] + data_json = json.loads(str_line) + return self.chunk_parser(chunk=data_json) + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") + + # Async iterator + def __aiter__(self): + self.async_response_iterator = self.streaming_response.__aiter__() + return self + + async def __anext__(self): + try: + chunk = await self.async_response_iterator.__anext__() + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + str_line = chunk + if isinstance(chunk, bytes): # Handle binary data + str_line = chunk.decode("utf-8") # Convert bytes to string + index = str_line.find("data:") + if index != -1: + str_line = str_line[index:] + + data_json = json.loads(str_line) + return self.chunk_parser(chunk=data_json) + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") diff --git a/litellm/main.py b/litellm/main.py index 5095ce518..6da7bb604 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1970,15 +1970,16 @@ def completion( # type: ignore # noqa: PLR0915 logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements ) - if "stream" in optional_params and optional_params["stream"] is True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="cohere_chat", - logging_obj=logging, - ) - return response + # if "stream" in optional_params and optional_params["stream"] is True: + # # don't try to access stream object, + # response = CustomStreamWrapper( + # model_response, + # model, + # custom_llm_provider="cohere_chat", + # logging_obj=logging, + # _response_headers=headers, + # ) + # return response response = model_response elif custom_llm_provider == "maritalk": maritalk_key = ( diff --git a/tests/llm_translation/test_cohere.py b/tests/llm_translation/test_cohere.py new file mode 100644 index 000000000..b8cf8d4d1 --- /dev/null +++ b/tests/llm_translation/test_cohere.py @@ -0,0 +1,59 @@ +import os +import sys +import traceback + +from dotenv import load_dotenv + +load_dotenv() +import io +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import json + +import pytest + +import litellm +from litellm import RateLimitError, Timeout, completion, completion_cost, embedding + +litellm.num_retries = 3 + + +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.asyncio +async def test_chat_completion_cohere_citations(stream): + try: + litellm.set_verbose = True + messages = [ + { + "role": "user", + "content": "Which penguins are the tallest?", + }, + ] + response = await litellm.acompletion( + model="cohere_chat/command-r", + messages=messages, + documents=[ + {"title": "Tall penguins", "text": "Emperor penguins are the tallest."}, + { + "title": "Penguin habitats", + "text": "Emperor penguins only live in Antarctica.", + }, + ], + stream=stream, + ) + + if stream: + citations_chunk = False + async for chunk in response: + print("received chunk", chunk) + if "citations" in chunk: + citations_chunk = True + break + assert citations_chunk + else: + assert response.citations is not None + except Exception as e: + pytest.fail(f"Error occurred: {e}") diff --git a/tests/local_testing/test_cohere_completion.py b/tests/local_testing/test_cohere_completion.py deleted file mode 100644 index e90818fee..000000000 --- a/tests/local_testing/test_cohere_completion.py +++ /dev/null @@ -1,210 +0,0 @@ -import os -import sys -import traceback - -from dotenv import load_dotenv - -load_dotenv() -import io -import os - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path -import json - -import pytest - -import litellm -from litellm import RateLimitError, Timeout, completion, completion_cost, embedding - -litellm.num_retries = 3 - - -# FYI - cohere_chat looks quite unstable, even when testing locally -def test_chat_completion_cohere(): - try: - litellm.set_verbose = True - messages = [ - { - "role": "user", - "content": "Hey", - }, - ] - response = completion( - model="cohere_chat/command-r", - messages=messages, - max_tokens=10, - ) - print(response) - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - -def test_chat_completion_cohere_tool_calling(): - try: - litellm.set_verbose = True - messages = [ - { - "role": "user", - "content": "What is the weather like in Boston?", - }, - ] - response = completion( - model="cohere_chat/command-r", - messages=messages, - tools=[ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - }, - }, - "required": ["location"], - }, - }, - } - ], - ) - print(response) - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - # def get_current_weather(location, unit="fahrenheit"): - # """Get the current weather in a given location""" - # if "tokyo" in location.lower(): - # return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit}) - # elif "san francisco" in location.lower(): - # return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit}) - # elif "paris" in location.lower(): - # return json.dumps({"location": "Paris", "temperature": "22", "unit": unit}) - # else: - # return json.dumps({"location": location, "temperature": "unknown"}) - - # def test_chat_completion_cohere_tool_with_result_calling(): - # # end to end cohere command-r with tool calling - # # Step 1 - Send available tools - # # Step 2 - Execute results - # # Step 3 - Send results to command-r - # try: - # litellm.set_verbose = True - # import json - - # # Step 1 - Send available tools - # tools = [ - # { - # "type": "function", - # "function": { - # "name": "get_current_weather", - # "description": "Get the current weather in a given location", - # "parameters": { - # "type": "object", - # "properties": { - # "location": { - # "type": "string", - # "description": "The city and state, e.g. San Francisco, CA", - # }, - # "unit": { - # "type": "string", - # "enum": ["celsius", "fahrenheit"], - # }, - # }, - # "required": ["location"], - # }, - # }, - # } - # ] - - # messages = [ - # { - # "role": "user", - # "content": "What is the weather like in Boston?", - # }, - # ] - # response = completion( - # model="cohere_chat/command-r", - # messages=messages, - # tools=tools, - # ) - # print("Response with tools to call", response) - # print(response) - - # # step 2 - Execute results - # tool_calls = response.tool_calls - - # available_functions = { - # "get_current_weather": get_current_weather, - # } # only one function in this example, but you can have multiple - - # for tool_call in tool_calls: - # function_name = tool_call.function.name - # function_to_call = available_functions[function_name] - # function_args = json.loads(tool_call.function.arguments) - # function_response = function_to_call( - # location=function_args.get("location"), - # unit=function_args.get("unit"), - # ) - # messages.append( - # { - # "tool_call_id": tool_call.id, - # "role": "tool", - # "name": function_name, - # "content": function_response, - # } - # ) # extend conversation with function response - - # print("messages with tool call results", messages) - - # messages = [ - # { - # "role": "user", - # "content": "What is the weather like in Boston?", - # }, - # { - # "tool_call_id": "tool_1", - # "role": "tool", - # "name": "get_current_weather", - # "content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"}, - # }, - # ] - # respone = completion( - # model="cohere_chat/command-r", - # messages=messages, - # tools=[ - # { - # "type": "function", - # "function": { - # "name": "get_current_weather", - # "description": "Get the current weather in a given location", - # "parameters": { - # "type": "object", - # "properties": { - # "location": { - # "type": "string", - # "description": "The city and state, e.g. San Francisco, CA", - # }, - # "unit": { - # "type": "string", - # "enum": ["celsius", "fahrenheit"], - # }, - # }, - # "required": ["location"], - # }, - # }, - # } - # ], - # ) - # print(respone) - except Exception as e: - pytest.fail(f"Error occurred: {e}") diff --git a/tests/local_testing/test_function_calling.py b/tests/local_testing/test_function_calling.py index 6e1bd13a1..52a1fc697 100644 --- a/tests/local_testing/test_function_calling.py +++ b/tests/local_testing/test_function_calling.py @@ -46,11 +46,12 @@ def get_current_weather(location, unit="fahrenheit"): "model", [ "gpt-3.5-turbo-1106", - # "mistral/mistral-large-latest", + "mistral/mistral-large-latest", "claude-3-haiku-20240307", "gemini/gemini-1.5-pro", "anthropic.claude-3-sonnet-20240229-v1:0", - # "groq/llama3-8b-8192", + "groq/llama3-8b-8192", + "cohere_chat/command-r", ], ) @pytest.mark.flaky(retries=3, delay=1)