fix(utils.py): fix streaming special character flushing logic

This commit is contained in:
Krrish Dholakia 2024-04-17 18:03:40 -07:00
parent 7d0086d742
commit 15ae7a8314
2 changed files with 8 additions and 8 deletions

View file

@ -8860,11 +8860,11 @@ class CustomStreamWrapper:
Output parse <s> / </s> special tokens for sagemaker + hf streaming.
"""
hold = False
if (
self.custom_llm_provider != "huggingface"
and self.custom_llm_provider != "sagemaker"
):
return hold, chunk
# if (
# self.custom_llm_provider != "huggingface"
# and self.custom_llm_provider != "sagemaker"
# ):
# return hold, chunk
if finish_reason:
for token in self.special_tokens:
@ -8881,6 +8881,7 @@ class CustomStreamWrapper:
for token in self.special_tokens:
if len(curr_chunk) < len(token) and curr_chunk in token:
hold = True
self.holding_chunk = curr_chunk
elif len(curr_chunk) >= len(token):
if token in curr_chunk:
self.holding_chunk = curr_chunk.replace(token, "")
@ -9962,6 +9963,7 @@ class CustomStreamWrapper:
f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"
)
print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
## RETURN ARG
if (
"content" in completion_obj
@ -10034,7 +10036,6 @@ class CustomStreamWrapper:
elif self.received_finish_reason is not None:
if self.sent_last_chunk == True:
raise StopIteration
# flush any remaining holding chunk
if len(self.holding_chunk) > 0:
if model_response.choices[0].delta.content is None: