From e5bb65669d926d3d20a13ea34b8341823c0843d5 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 12 Mar 2024 10:45:42 -0700
Subject: [PATCH] (feat) exception mapping for cohere_chat

---
 litellm/tests/test_completion.py | 23 ++++++++++++++++++++
 litellm/utils.py                 | 36 +++++++++++++++++++++++++++++++-
 2 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 0bb26ad68..b298cec4a 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1980,6 +1980,29 @@ def test_chat_completion_cohere():
         pytest.fail(f"Error occurred: {e}")
 
 
+def test_chat_completion_cohere_stream():
+    try:
+        litellm.set_verbose = False
+        messages = [
+            {"role": "system", "content": "You're a good bot"},
+            {
+                "role": "user",
+                "content": "Hey",
+            },
+        ]
+        response = completion(
+            model="cohere_chat/command-r",
+            messages=messages,
+            max_tokens=10,
+            stream=True,
+        )
+        print(response)
+        for chunk in response:
+            print(chunk)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_azure_cloudflare_api():
     litellm.set_verbose = True
     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 3b6169770..5caea73b0 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -7411,7 +7411,9 @@ def exception_type(
                     model=model,
                     response=original_exception.response,
                 )
-        elif custom_llm_provider == "cohere":  # Cohere
+        elif (
+            custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat"
+        ):  # Cohere
             if (
                 "invalid api token" in error_str
                 or "No API key provided." in error_str
@@ -8544,6 +8546,29 @@ class CustomStreamWrapper:
         except:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")
 
+    def handle_cohere_chat_chunk(self, chunk):
+        chunk = chunk.decode("utf-8")
+        data_json = json.loads(chunk)
+        print_verbose(f"chunk: {chunk}")
+        try:
+            text = ""
+            is_finished = False
+            finish_reason = ""
+            if "text" in data_json:
+                text = data_json["text"]
+            elif "is_finished" in data_json and data_json["is_finished"] == True:
+                is_finished = data_json["is_finished"]
+                finish_reason = data_json["finish_reason"]
+            else:
+                return
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
+        except:
+            raise ValueError(f"Unable to parse response. Original response: {chunk}")
+
     def handle_azure_chunk(self, chunk):
         is_finished = False
         finish_reason = ""
@@ -9052,6 +9077,15 @@ class CustomStreamWrapper:
                     model_response.choices[0].finish_reason = response_obj[
                         "finish_reason"
                     ]
+            elif self.custom_llm_provider == "cohere_chat":
+                response_obj = self.handle_cohere_chat_chunk(chunk)
+                if response_obj is None:
+                    return
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj[
+                        "finish_reason"
+                    ]
             elif self.custom_llm_provider == "bedrock":
                 if self.sent_last_chunk:
                     raise StopIteration
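
The new test exercises this end to end under pytest; outside the test suite, the same
streaming path can be driven directly. A minimal sketch, assuming litellm is installed
and a Cohere API key is available in the environment (litellm conventionally reads
COHERE_API_KEY; treat the exact key setup as an assumption about your environment):

    import litellm

    # The "cohere_chat/" prefix routes through Cohere's chat API rather than
    # the legacy generate API, so streaming goes through the new
    # handle_cohere_chat_chunk handler added in this patch.
    response = litellm.completion(
        model="cohere_chat/command-r",
        messages=[{"role": "user", "content": "Hey"}],
        max_tokens=10,
        stream=True,
    )
    for chunk in response:
        # Chunks arrive as OpenAI-style deltas assembled by
        # CustomStreamWrapper from the parsed Cohere events.
        print(chunk.choices[0].delta)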
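
handle_cohere_chat_chunk distinguishes three cases: a payload carrying "text" becomes
streamed content, a payload with "is_finished" set carries the finish reason, and any
other event is dropped (chunk_creator treats a None return as a no-op chunk). A
self-contained restatement of that logic, fed illustrative payloads shaped after the
keys the handler reads -- the event_type values are assumptions for illustration, not
captured Cohere traffic:

    import json

    def parse_cohere_chat_chunk(chunk: bytes):
        data_json = json.loads(chunk.decode("utf-8"))
        text, is_finished, finish_reason = "", False, ""
        if "text" in data_json:
            text = data_json["text"]
        elif data_json.get("is_finished") is True:
            is_finished = True
            finish_reason = data_json["finish_reason"]
        else:
            return None  # an event the wrapper does not surface
        return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}

    print(parse_cohere_chat_chunk(b'{"event_type": "text-generation", "text": "Hi"}'))
    # {'text': 'Hi', 'is_finished': False, 'finish_reason': ''}
    print(parse_cohere_chat_chunk(b'{"event_type": "stream-end", "is_finished": true, "finish_reason": "COMPLETE"}'))
    # {'text': '', 'is_finished': True, 'finish_reason': 'COMPLETE'}

One consequence of the if/elif ordering: a payload containing both "text" and
"is_finished" would be treated as a plain text chunk and its finish reason ignored,
so the handler implicitly assumes Cohere never sends both in a single event.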
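
The utils.py change widens the existing Cohere branch of exception_type, so errors
from cohere_chat should map to the same typed litellm exceptions as the legacy cohere
provider. A hedged sketch of what that buys a caller -- this assumes the branch raises
litellm's AuthenticationError for the "invalid api token" string, as the surrounding
cohere mapping does (the raise itself is outside the hunk shown above):

    import litellm

    try:
        litellm.completion(
            model="cohere_chat/command-r",
            messages=[{"role": "user", "content": "Hey"}],
            api_key="not-a-real-key",  # deliberately invalid
        )
    except litellm.exceptions.AuthenticationError as e:
        # Without the mapping, this would surface as a raw provider error
        # instead of a typed exception callers can handle uniformly.
        print("mapped to:", type(e).__name__)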