fix(utils.py): remove special characters from streaming output

commit 6e7e409615
parent a2f2fd3841
Author: Krrish Dholakia
Date:   2023-11-06 12:21:31 -08:00

2 changed files with 60 additions and 34 deletions


@@ -36,7 +36,7 @@ def test_completion_custom_provider_model_name():
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_custom_provider_model_name()
+# test_completion_custom_provider_model_name()
 
 def test_completion_claude():
@@ -221,34 +221,32 @@ def test_get_hf_task_for_model():
 # ################### Hugging Face TGI models ########################
 # # TGI model
 # # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
-# def hf_test_completion_tgi():
-#     litellm.set_verbose=True
-#     try:
-#         response = litellm.completion(
-#             model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",
-#             messages=[{ "content": "Hello, how are you?","role": "user"}],
-#             api_base="https://3kk3h56912qga4-80.proxy.runpod.net",
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def hf_test_completion_tgi():
+    # litellm.set_verbose=True
+    try:
+        response = completion(
+            model = 'huggingface/HuggingFaceH4/zephyr-7b-beta',
+            messages = [{ "content": "Hello, how are you?","role": "user"}],
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 # hf_test_completion_tgi()
 
-# def hf_test_completion_tgi_stream():
-#     try:
-#         response = litellm.completion(
-#             model="huggingface/glaiveai/glaive-coder-7b",
-#             messages=[{ "content": "Hello, how are you?","role": "user"}],
-#             api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud",
-#             stream=True
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#         for chunk in response:
-#             print(chunk)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def hf_test_completion_tgi_stream():
+    try:
+        response = completion(
+            model = 'huggingface/HuggingFaceH4/zephyr-7b-beta',
+            messages = [{ "content": "Hello, how are you?","role": "user"}],
+            stream=True
+        )
+        # Add any assertions here to check the response
+        print(response)
+        for chunk in response:
+            print(chunk["choices"][0]["delta"]["content"])
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 # hf_test_completion_tgi_stream()
 
 # ################### Hugging Face Conversational models ########################


@@ -3686,6 +3686,8 @@ class CustomStreamWrapper:
         self.completion_stream = completion_stream
         self.sent_first_chunk = False
         self.sent_last_chunk = False
+        self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
+        self.holding_chunk = ""
         if self.logging_obj:
             # Log the type of the received item
             self.logging_obj.post_call(str(type(completion_stream)))
@@ -3699,6 +3701,29 @@ class CustomStreamWrapper:
     def logging(self, text):
         if self.logging_obj:
             self.logging_obj.post_call(text)
 
+    def check_special_tokens(self, chunk: str):
+        hold = False
+        if self.sent_first_chunk is True:
+            return hold, chunk
+
+        curr_chunk = self.holding_chunk + chunk
+        curr_chunk = curr_chunk.strip()
+
+        for token in self.special_tokens:
+            if len(curr_chunk) < len(token) and curr_chunk in token:
+                hold = True
+            elif len(curr_chunk) >= len(token):
+                if token in curr_chunk:
+                    self.holding_chunk = curr_chunk.replace(token, "")
+                    hold = True
+            else:
+                pass
+
+        if hold is False: # reset
+            self.holding_chunk = ""
+        return hold, curr_chunk
+
     def handle_anthropic_chunk(self, chunk):
         str_line = chunk.decode("utf-8")  # Convert bytes to string
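
The new method is a hold-and-release filter for the start of a stream: a chunk that could still be the prefix of a special token is held back, a chunk containing a complete token has the token stripped and the remainder buffered, and anything else flows through untouched. Below is a minimal, self-contained sketch of that behavior; StreamFilter is a hypothetical stand-in for CustomStreamWrapper that reproduces only the fields this hunk touches.

class StreamFilter:
    def __init__(self):
        self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
        self.holding_chunk = ""
        self.sent_first_chunk = False

    def check_special_tokens(self, chunk: str):
        hold = False
        if self.sent_first_chunk is True:  # filtering only applies before the first emitted chunk
            return hold, chunk
        curr_chunk = (self.holding_chunk + chunk).strip()
        for token in self.special_tokens:
            if len(curr_chunk) < len(token) and curr_chunk in token:
                hold = True  # could still be a token prefix; keep holding
            elif len(curr_chunk) >= len(token) and token in curr_chunk:
                self.holding_chunk = curr_chunk.replace(token, "")  # strip the token, buffer the rest
                hold = True
        if hold is False:  # nothing special in play; reset the buffer
            self.holding_chunk = ""
        return hold, curr_chunk

f = StreamFilter()
print(f.check_special_tokens("<s> Hello"))  # (True, '<s> Hello') -- "<s>" stripped, "Hello" buffered
print(f.check_special_tokens(", world"))    # (False, 'Hello, world') -- buffered text released clean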
@@ -4100,13 +4125,16 @@ class CustomStreamWrapper:
             model_response.model = self.model
             if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
-                if self.sent_first_chunk == False:
-                    completion_obj["role"] = "assistant"
-                    self.sent_first_chunk = True
-                model_response.choices[0].delta = Delta(**completion_obj)
-                # LOGGING
-                threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
-                return model_response
+                hold, model_response_str = self.check_special_tokens(completion_obj["content"])
+                if hold is False:
+                    completion_obj["content"] = model_response_str
+                    if self.sent_first_chunk == False:
+                        completion_obj["role"] = "assistant"
+                        self.sent_first_chunk = True
+                    model_response.choices[0].delta = Delta(**completion_obj)
+                    # LOGGING
+                    threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
+                    return model_response
             elif model_response.choices[0].finish_reason:
                 model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
                 # LOGGING
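
At the call site, chunk content is now routed through check_special_tokens, and a chunk is only turned into a Delta and returned once hold comes back False, so special tokens never reach the emitted deltas. A consumer-side sketch, modeled on the streaming test updated above (the model name is simply the one that test happens to use):

from litellm import completion

response = completion(
    model="huggingface/HuggingFaceH4/zephyr-7b-beta",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    stream=True,
)
for chunk in response:
    # deltas arrive without leading markers such as "<s>" or "<|assistant|>"
    print(chunk["choices"][0]["delta"]["content"])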