mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
fix(utils.py): remove special characters from streaming output
This commit is contained in:
parent
a2f2fd3841
commit
6e7e409615
2 changed files with 60 additions and 34 deletions
|
@ -3686,6 +3686,8 @@ class CustomStreamWrapper:
|
|||
self.completion_stream = completion_stream
|
||||
self.sent_first_chunk = False
|
||||
self.sent_last_chunk = False
|
||||
self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
|
||||
self.holding_chunk = ""
|
||||
if self.logging_obj:
|
||||
# Log the type of the received item
|
||||
self.logging_obj.post_call(str(type(completion_stream)))
|
||||
|
@ -3699,6 +3701,29 @@ class CustomStreamWrapper:
|
|||
def logging(self, text):
|
||||
if self.logging_obj:
|
||||
self.logging_obj.post_call(text)
|
||||
|
||||
def check_special_tokens(self, chunk: str):
|
||||
hold = False
|
||||
if self.sent_first_chunk is True:
|
||||
return hold, chunk
|
||||
|
||||
curr_chunk = self.holding_chunk + chunk
|
||||
curr_chunk = curr_chunk.strip()
|
||||
|
||||
for token in self.special_tokens:
|
||||
if len(curr_chunk) < len(token) and curr_chunk in token:
|
||||
hold = True
|
||||
elif len(curr_chunk) >= len(token):
|
||||
if token in curr_chunk:
|
||||
self.holding_chunk = curr_chunk.replace(token, "")
|
||||
hold = True
|
||||
else:
|
||||
pass
|
||||
|
||||
if hold is False: # reset
|
||||
self.holding_chunk = ""
|
||||
return hold, curr_chunk
|
||||
|
||||
|
||||
def handle_anthropic_chunk(self, chunk):
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
|
@ -4100,13 +4125,16 @@ class CustomStreamWrapper:
|
|||
|
||||
model_response.model = self.model
|
||||
if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
|
||||
if self.sent_first_chunk == False:
|
||||
completion_obj["role"] = "assistant"
|
||||
self.sent_first_chunk = True
|
||||
model_response.choices[0].delta = Delta(**completion_obj)
|
||||
# LOGGING
|
||||
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
|
||||
return model_response
|
||||
hold, model_response_str = self.check_special_tokens(completion_obj["content"])
|
||||
if hold is False:
|
||||
completion_obj["content"] = model_response_str
|
||||
if self.sent_first_chunk == False:
|
||||
completion_obj["role"] = "assistant"
|
||||
self.sent_first_chunk = True
|
||||
model_response.choices[0].delta = Delta(**completion_obj)
|
||||
# LOGGING
|
||||
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
|
||||
return model_response
|
||||
elif model_response.choices[0].finish_reason:
|
||||
model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
|
||||
# LOGGING
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue