feat(sagemaker.py): aioboto3 streaming support

commit 5de569fcb1
parent 23c410a548
Author: Krrish Dholakia
Date:   2024-02-12 21:18:34 -08:00

4 changed files with 80 additions and 13 deletions

diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py

@@ -30,8 +30,11 @@ import json
 
 
 class TokenIterator:
-    def __init__(self, stream):
-        self.byte_iterator = iter(stream)
+    def __init__(self, stream, acompletion: bool):
+        if acompletion == False:
+            self.byte_iterator = iter(stream)
+        elif acompletion == True:
+            self.byte_iterator = stream
         self.buffer = io.BytesIO()
         self.read_pos = 0
         self.end_of_data = False
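
Why the constructor needs the flag: iter() works on boto3's blocking EventStream, but aioboto3 hands back an async iterator, which iter() rejects. A quick illustration (not part of the commit):

async def agen():
    yield {"PayloadPart": {"Bytes": b""}}

try:
    iter(agen())  # what the old __init__ did to every stream
except TypeError as err:
    print(err)  # 'async_generator' object is not iterable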
@@ -64,6 +67,34 @@ class TokenIterator:
             self.end_of_data = True
             return "data: [DONE]"
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            while True:
+                self.buffer.seek(self.read_pos)
+                line = self.buffer.readline()
+                if line and line[-1] == ord("\n"):
+                    response_obj = {"text": "", "is_finished": False}
+                    self.read_pos += len(line) + 1
+                    full_line = line[:-1].decode("utf-8")
+                    line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
+                    if line_data.get("generated_text", None) is not None:
+                        self.end_of_data = True
+                        response_obj["is_finished"] = True
+                    response_obj["text"] = line_data["token"]["text"]
+                    return response_obj
+                chunk = await anext(self.byte_iterator)
+                self.buffer.seek(0, io.SEEK_END)
+                self.buffer.write(chunk["PayloadPart"]["Bytes"])
+        except StopAsyncIteration as e:
+            if self.end_of_data == True:
+                raise e  # Re-raise StopAsyncIteration
+            else:
+                self.end_of_data = True
+                return "data: [DONE]"
 
 
 class SagemakerConfig:
     """
@@ -197,15 +228,16 @@ def completion(
         data = json.dumps(
             {"inputs": prompt, "parameters": inference_params, "stream": True}
         ).encode("utf-8")
-        ## LOGGING
-        request_str = f"""
-        response = client.invoke_endpoint_with_response_stream(
-            EndpointName={model},
-            ContentType="application/json",
-            Body={data},
-            CustomAttributes="accept_eula=true",
-        )
-        """  # type: ignore
+        if acompletion == True:
+            response = async_streaming(
+                optional_params=optional_params,
+                encoding=encoding,
+                model_response=model_response,
+                model=model,
+                logging_obj=logging_obj,
+                data=data,
+            )
+            return response
         response = client.invoke_endpoint_with_response_stream(
             EndpointName=model,
             ContentType="application/json",
@@ -311,6 +343,37 @@ def completion(
         return model_response
 
 
+async def async_streaming(
+    optional_params,
+    encoding,
+    model_response: ModelResponse,
+    model: str,
+    logging_obj: Any,
+    data,
+):
+    """
+    Use aioboto3
+    """
+    import aioboto3
+
+    session = aioboto3.Session()
+    async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
+        try:
+            response = await client.invoke_endpoint_with_response_stream(
+                EndpointName=model,
+                ContentType="application/json",
+                Body=data,
+                CustomAttributes="accept_eula=true",
+            )
+        except Exception as e:
+            raise SagemakerError(status_code=500, message=f"{str(e)}")
+        response = response["Body"]
+        # filtered_response = TokenIterator(stream=response, acompletion=True)
+        async for chunk in response:
+            yield chunk
+        # return
+
+
 async def async_completion(
     optional_params,
     encoding,
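
The generator yields raw PayloadPart chunks; parsing is left to TokenIterator downstream (note the commented-out wrapping above). For reference, a stripped-down sketch of the same aioboto3 call outside litellm; the endpoint name is a placeholder, the region is hard-coded to match the commit, and AWS credentials are assumed to be configured:

import asyncio
import json

import aioboto3

async def stream_raw(endpoint: str, prompt: str):
    session = aioboto3.Session()
    async with session.client("sagemaker-runtime", region_name="us-west-2") as client:
        response = await client.invoke_endpoint_with_response_stream(
            EndpointName=endpoint,
            ContentType="application/json",
            Body=json.dumps({"inputs": prompt, "stream": True}).encode("utf-8"),
        )
        # response["Body"] is an async event stream of {"PayloadPart": {"Bytes": ...}}
        async for event in response["Body"]:
            print(event["PayloadPart"]["Bytes"])

asyncio.run(stream_raw("my-endpoint", "Hello"))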

diff --git a/litellm/main.py b/litellm/main.py

@@ -1562,7 +1562,7 @@ def completion(
             print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
             from .llms.sagemaker import TokenIterator
 
-            tokenIterator = TokenIterator(model_response)
+            tokenIterator = TokenIterator(model_response, acompletion=acompletion)
             response = CustomStreamWrapper(
                 completion_stream=tokenIterator,
                 model=model,

diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py

@@ -876,7 +876,6 @@ async def test_sagemaker_streaming_async():
             temperature=0.7,
             stream=True,
         )
-
         # Add any assertions here to check the response
         print(response)
         complete_response = ""
@@ -900,6 +899,9 @@ async def test_sagemaker_streaming_async():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
+asyncio.run(test_sagemaker_streaming_async())
+
+
 def test_completion_sagemaker_stream():
     try:
         response = completion(
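
For reference, the user-facing call the async test exercises looks roughly like this (model id is a placeholder):

import asyncio

import litellm

async def main():
    response = await litellm.acompletion(
        model="sagemaker/<your-endpoint-name>",  # placeholder endpoint
        messages=[{"role": "user", "content": "hi"}],
        temperature=0.7,
        stream=True,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())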

diff --git a/litellm/utils.py b/litellm/utils.py

@@ -8691,6 +8691,8 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "ollama"
                 or self.custom_llm_provider == "ollama_chat"
                 or self.custom_llm_provider == "vertex_ai"
+                or self.custom_llm_provider == "sagemaker"
+                or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 print_verbose(
                     f"value of async completion stream: {self.completion_stream}"