From 94dc3f66f338a49e6b3bf4f8c3d2c1093ddd1399 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 23 Nov 2023 17:47:39 -0800
Subject: [PATCH] fix(utils.py): remove eos token for zephyr models

---
 litellm/tests/test_completion.py | 15 ---------------
 litellm/tests/test_streaming.py  | 25 ++++++++++++++++++++++++-
 litellm/utils.py                 | 11 +++++++++--
 3 files changed, 33 insertions(+), 18 deletions(-)

diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index d2300fd386..e905fb464a 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -286,21 +286,6 @@ def hf_test_completion_tgi():
         pytest.fail(f"Error occurred: {e}")
 # hf_test_completion_tgi()
 
-def hf_test_completion_tgi_stream():
-    try:
-        response = completion(
-            model = 'huggingface/HuggingFaceH4/zephyr-7b-beta',
-            messages = [{ "content": "Hello, how are you?","role": "user"}],
-            stream=True
-        )
-        # Add any assertions here to check the response
-        print(response)
-        for chunk in response:
-            print(chunk["choices"][0]["delta"]["content"])
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-# hf_test_completion_tgi_stream()
-
 # ################### Hugging Face Conversational models ########################
 # def hf_test_completion_conv():
 #     try:
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index c76599c002..1a635460a2 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -631,6 +631,29 @@ def ai21_completion_call_bad_key():
 
 # ai21_completion_call_bad_key()
 
+def hf_test_completion_tgi_stream():
+    try:
+        response = completion(
+            model = 'huggingface/HuggingFaceH4/zephyr-7b-beta',
+            messages = [{ "content": "Hello, how are you?","role": "user"}],
+            stream=True
+        )
+        # Add any assertions here to check the response
+        print(f"response: {response}")
+        complete_response = ""
+        start_time = time.time()
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            complete_response += chunk
+            if finished:
+                break
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+hf_test_completion_tgi_stream()
+
 # def test_completion_aleph_alpha():
 #     try:
 #         response = completion(
@@ -706,7 +729,7 @@ def test_openai_chat_completion_call():
         print(f"error occurred: {traceback.format_exc()}")
         pass
 
-test_openai_chat_completion_call()
+# test_openai_chat_completion_call()
 
 def test_openai_chat_completion_complete_response_call():
     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 2a2dc9d204..f7462ccab2 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4538,8 +4538,14 @@ class CustomStreamWrapper:
         if self.logging_obj:
             self.logging_obj.post_call(text)
 
-    def check_special_tokens(self, chunk: str):
+    def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
         hold = False
+        if finish_reason:
+            for token in self.special_tokens:
+                if token in chunk:
+                    chunk = chunk.replace(token, "")
+            return hold, chunk
+
         if self.sent_first_chunk is True:
             return hold, chunk
 
@@ -4996,8 +5002,9 @@ class CustomStreamWrapper:
             model_response.model = self.model
             print_verbose(f"model_response: {model_response}; completion_obj: {completion_obj}")
             print_verbose(f"model_response finish reason 3: {model_response.choices[0].finish_reason}")
+
             if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
-                hold, model_response_str = self.check_special_tokens(completion_obj["content"])
+                hold, model_response_str = self.check_special_tokens(chunk=completion_obj["content"], finish_reason=model_response.choices[0].finish_reason)
                 if hold is False:
                     completion_obj["content"] = model_response_str
                     if self.sent_first_chunk == False: