feat(sagemaker.py): aioboto3 streaming support

Krrish Dholakia 2024-02-12 21:18:34 -08:00
parent 23c410a548
commit 5de569fcb1
4 changed files with 80 additions and 13 deletions
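For orientation, a minimal sketch of the aioboto3 streaming-invoke pattern this commit wires in below. The endpoint name and request body are placeholders, not values from the commit; the region and API calls mirror the diff.

# Hedged sketch only, not part of the commit. "my-endpoint" and the request
# body are made up; region and calls follow the async_streaming code below.
import asyncio
import json

import aioboto3


async def invoke_stream():
    session = aioboto3.Session()
    async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
        response = await client.invoke_endpoint_with_response_stream(
            EndpointName="my-endpoint",  # placeholder
            ContentType="application/json",
            Body=json.dumps(
                {"inputs": "Hello", "parameters": {}, "stream": True}
            ).encode("utf-8"),
            CustomAttributes="accept_eula=true",
        )
        # The response body is an async event stream of PayloadPart chunks.
        async for event in response["Body"]:
            print(event["PayloadPart"]["Bytes"])


asyncio.run(invoke_stream())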


@@ -30,8 +30,11 @@ import json
class TokenIterator:
    def __init__(self, stream):
        self.byte_iterator = iter(stream)
    def __init__(self, stream, acompletion: bool):
        if acompletion == False:
            self.byte_iterator = iter(stream)
        elif acompletion == True:
            self.byte_iterator = stream
        self.buffer = io.BytesIO()
        self.read_pos = 0
        self.end_of_data = False
@@ -64,6 +67,34 @@ class TokenIterator:
                self.end_of_data = True
                return "data: [DONE]"

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            while True:
                self.buffer.seek(self.read_pos)
                line = self.buffer.readline()
                if line and line[-1] == ord("\n"):
                    response_obj = {"text": "", "is_finished": False}
                    self.read_pos += len(line) + 1
                    full_line = line[:-1].decode("utf-8")
                    line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
                    if line_data.get("generated_text", None) is not None:
                        self.end_of_data = True
                        response_obj["is_finished"] = True
                    response_obj["text"] = line_data["token"]["text"]
                    return response_obj
                chunk = await anext(self.byte_iterator)
                self.buffer.seek(0, io.SEEK_END)
                self.buffer.write(chunk["PayloadPart"]["Bytes"])
        except StopAsyncIteration as e:
            if self.end_of_data == True:
                raise e  # Re-raise StopIteration
            else:
                self.end_of_data = True
                return "data: [DONE]"


class SagemakerConfig:
    """
@@ -197,15 +228,16 @@ def completion(
        data = json.dumps(
            {"inputs": prompt, "parameters": inference_params, "stream": True}
        ).encode("utf-8")
        ## LOGGING
        request_str = f"""
        response = client.invoke_endpoint_with_response_stream(
            EndpointName={model},
            ContentType="application/json",
            Body={data},
            CustomAttributes="accept_eula=true",
        )
        """  # type: ignore
        if acompletion == True:
            response = async_streaming(
                optional_params=optional_params,
                encoding=encoding,
                model_response=model_response,
                model=model,
                logging_obj=logging_obj,
                data=data,
            )
            return response
        response = client.invoke_endpoint_with_response_stream(
            EndpointName=model,
            ContentType="application/json",
@@ -311,6 +343,37 @@ def completion(
    return model_response


async def async_streaming(
    optional_params,
    encoding,
    model_response: ModelResponse,
    model: str,
    logging_obj: Any,
    data,
):
    """
    Use aioboto3
    """
    import aioboto3

    session = aioboto3.Session()
    async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
        try:
            response = await client.invoke_endpoint_with_response_stream(
                EndpointName=model,
                ContentType="application/json",
                Body=data,
                CustomAttributes="accept_eula=true",
            )
        except Exception as e:
            raise SagemakerError(status_code=500, message=f"{str(e)}")
        response = response["Body"]
        # filtered_response = TokenIterator(stream=response, acompletion=True)
        async for chunk in response:
            yield chunk
        # return


async def async_completion(
    optional_params,
    encoding,