diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index f29aad38f..b4394d345 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -36,7 +36,7 @@ def test_completion_custom_provider_model_name():
         pytest.fail(f"Error occurred: {e}")
 
 
-test_completion_custom_provider_model_name()
+# test_completion_custom_provider_model_name()
 
 
 def test_completion_claude():
@@ -221,34 +221,32 @@ def test_get_hf_task_for_model():
 # ################### Hugging Face TGI models ########################
 # # TGI model
 # # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
-# def hf_test_completion_tgi():
-#     litellm.set_verbose=True
-#     try:
-#         response = litellm.completion(
-#             model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",
-#             messages=[{ "content": "Hello, how are you?","role": "user"}],
-#             api_base="https://3kk3h56912qga4-80.proxy.runpod.net",
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def hf_test_completion_tgi():
+    # litellm.set_verbose=True
+    try:
+        response = completion(
+            model = 'huggingface/HuggingFaceH4/zephyr-7b-beta',
+            messages = [{ "content": "Hello, how are you?","role": "user"}],
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 # hf_test_completion_tgi()
 
-# def hf_test_completion_tgi_stream():
-#     try:
-#         response = litellm.completion(
-#             model="huggingface/glaiveai/glaive-coder-7b",
-#             messages=[{ "content": "Hello, how are you?","role": "user"}],
-#             api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud",
-#             stream=True
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#         for chunk in response:
-#             print(chunk)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def hf_test_completion_tgi_stream():
+    try:
+        response = completion(
+            model = 'huggingface/HuggingFaceH4/zephyr-7b-beta',
+            messages = [{ "content": "Hello, how are you?","role": "user"}],
+            stream=True
+        )
+        # Add any assertions here to check the response
+        print(response)
+        for chunk in response:
+            print(chunk["choices"][0]["delta"]["content"])
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 # hf_test_completion_tgi_stream()
 
 # ################### Hugging Face Conversational models ########################
diff --git a/litellm/utils.py b/litellm/utils.py
index ddc9b0825..8291c20e9 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3686,6 +3686,8 @@ class CustomStreamWrapper:
         self.completion_stream = completion_stream
         self.sent_first_chunk = False
         self.sent_last_chunk = False
+        self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
+        self.holding_chunk = ""
         if self.logging_obj:
             # Log the type of the received item
             self.logging_obj.post_call(str(type(completion_stream)))
@@ -3699,6 +3701,29 @@
     def logging(self, text):
         if self.logging_obj:
             self.logging_obj.post_call(text)
 
+
+    def check_special_tokens(self, chunk: str):
+        hold = False
+        if self.sent_first_chunk is True:
+            return hold, chunk
+
+        curr_chunk = self.holding_chunk + chunk
+        curr_chunk = curr_chunk.strip()
+
+        for token in self.special_tokens:
+            if len(curr_chunk) < len(token) and curr_chunk in token:
+                hold = True
+            elif len(curr_chunk) >= len(token):
+                if token in curr_chunk:
+                    self.holding_chunk = curr_chunk.replace(token, "")
+                    hold = True
+            else:
+                pass
+
+        if hold is False: # reset
+            self.holding_chunk = ""
+        return hold, curr_chunk
+
     def handle_anthropic_chunk(self, chunk):
         str_line = chunk.decode("utf-8")  # Convert bytes to string
@@ -4100,13 +4125,16 @@
             model_response.model = self.model
             if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
-                if self.sent_first_chunk == False:
-                    completion_obj["role"] = "assistant"
-                    self.sent_first_chunk = True
-                model_response.choices[0].delta = Delta(**completion_obj)
-                # LOGGING
-                threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
-                return model_response
+                hold, model_response_str = self.check_special_tokens(completion_obj["content"])
+                if hold is False:
+                    completion_obj["content"] = model_response_str
+                    if self.sent_first_chunk == False:
+                        completion_obj["role"] = "assistant"
+                        self.sent_first_chunk = True
+                    model_response.choices[0].delta = Delta(**completion_obj)
+                    # LOGGING
+                    threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
+                    return model_response
             elif model_response.choices[0].finish_reason:
                 model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
                 # LOGGING