fix(utils.py): fix sagemaker async logging for sync streaming

https://github.com/BerriAI/litellm/issues/1592
This commit is contained in:
Krrish Dholakia 2024-01-25 12:49:45 -08:00
parent 39d5407e67
commit 09ec6d6458
10 changed files with 247 additions and 64 deletions

View file

@ -34,22 +34,35 @@ class TokenIterator:
self.byte_iterator = iter(stream)
self.buffer = io.BytesIO()
self.read_pos = 0
self.end_of_data = False
def __iter__(self):
return self
def __next__(self):
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
self.read_pos += len(line) + 1
full_line = line[:-1].decode("utf-8")
line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
return line_data["token"]["text"]
chunk = next(self.byte_iterator)
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
try:
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
response_obj = {"text": "", "is_finished": False}
self.read_pos += len(line) + 1
full_line = line[:-1].decode("utf-8")
line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
if line_data.get("generated_text", None) is not None:
self.end_of_data = True
response_obj["is_finished"] = True
response_obj["text"] = line_data["token"]["text"]
return response_obj
chunk = next(self.byte_iterator)
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
except StopIteration as e:
if self.end_of_data == True:
raise e # Re-raise StopIteration
else:
self.end_of_data = True
return "data: [DONE]"
class SagemakerConfig: