fix(utils.py): remove special characters from streaming output

This commit is contained in:
Krrish Dholakia 2023-11-06 12:21:31 -08:00
parent a2f2fd3841
commit 6e7e409615
2 changed files with 60 additions and 34 deletions

View file

@ -3686,6 +3686,8 @@ class CustomStreamWrapper:
self.completion_stream = completion_stream
self.sent_first_chunk = False
self.sent_last_chunk = False
self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
self.holding_chunk = ""
if self.logging_obj:
# Log the type of the received item
self.logging_obj.post_call(str(type(completion_stream)))
@ -3699,6 +3701,29 @@ class CustomStreamWrapper:
def logging(self, text):
    """Forward `text` to the logging object's `post_call`, if one is set."""
    logger = self.logging_obj
    if not logger:
        # No logging object configured: nothing to record.
        return
    logger.post_call(text)
def check_special_tokens(self, chunk: str):
    """
    Decide whether the opening text of a stream must be held back because it
    is, or may still turn out to be, one of `self.special_tokens`.

    The check only applies until the first chunk has been sent; once
    `self.sent_first_chunk` is True the chunk is passed through untouched.

    Args:
        chunk: the newly received piece of text.

    Returns:
        A `(hold, text)` tuple. `hold` is True when nothing should be emitted
        yet. `text` is the previously held text plus `chunk`, stripped of
        surrounding whitespace; when `hold` is True it may still contain
        special tokens, so callers should only use it when `hold` is False.
    """
    if self.sent_first_chunk is True:
        return False, chunk

    # Keep the unstripped buffer so whitespace between a held piece and the
    # next chunk survives when they are joined on the following call.
    buffered = self.holding_chunk + chunk
    curr_chunk = buffered.strip()

    hold = False
    remainder = buffered
    for token in self.special_tokens:
        if len(curr_chunk) < len(token):
            # Too short to contain the token, but it may be a fragment of it.
            if curr_chunk in token:
                hold = True
        elif token in curr_chunk:
            # Whole token present: drop it, keep whatever else arrived with it.
            remainder = remainder.replace(token, "")
            hold = True

    # Held text is always saved. Discarding it in the fragment case would lose
    # ordinary output such as a reply that merely starts with "s" or "<".
    self.holding_chunk = remainder if hold else ""
    return hold, curr_chunk
def handle_anthropic_chunk(self, chunk):
str_line = chunk.decode("utf-8") # Convert bytes to string
@ -4100,13 +4125,16 @@ class CustomStreamWrapper:
model_response.model = self.model
if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
if self.sent_first_chunk == False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
return model_response
hold, model_response_str = self.check_special_tokens(completion_obj["content"])
if hold is False:
completion_obj["content"] = model_response_str
if self.sent_first_chunk == False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
return model_response
elif model_response.choices[0].finish_reason:
model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
# LOGGING