Merge pull request #1561 from BerriAI/litellm_sagemaker_streaming

[Feat] Add REAL Sagemaker streaming
Ishaan Jaff 2024-01-22 22:10:20 -08:00 committed by GitHub
commit 97dd61a6cb
5 changed files with 141 additions and 15 deletions


@@ -0,0 +1,61 @@
# Notes - on how to do sagemaker streaming using boto3
import io
import json
import os
import sys
import traceback

import boto3
from dotenv import load_dotenv

load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm


class TokenIterator:
    """Re-assembles TGI server-sent events from raw PayloadPart byte chunks."""

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                # a full `data:{...}` line is buffered; advance past it plus
                # the blank SSE separator line, then return the decoded token
                self.read_pos += len(line) + 1
                full_line = line[:-1].decode("utf-8")
                line_data = json.loads(full_line.lstrip("data:").rstrip("\n"))
                return line_data["token"]["text"]
            # no complete line buffered yet -- pull in the next PayloadPart
            chunk = next(self.byte_iterator)
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])


payload = {
    "inputs": "How do I build a website?",
    "parameters": {"max_new_tokens": 256},
    "stream": True,
}

client = boto3.client("sagemaker-runtime", region_name="us-west-2")
response = client.invoke_endpoint_with_response_stream(
    EndpointName="berri-benchmarking-Llama-2-70b-chat-hf-4",
    Body=json.dumps(payload),
    ContentType="application/json",
)

# for token in TokenIterator(response["Body"]):
#     print(token)
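Because the endpoint's server-sent events arrive as arbitrarily sized PayloadPart byte chunks, a single `data:{...}` line can be split across events; the buffering in TokenIterator re-assembles them. A minimal local check of that behavior with synthetic chunks (no AWS call involved; the event shape mirrors what boto3's EventStream yields for a TGI-style endpoint):

# Hedged sketch: drive TokenIterator with hand-built PayloadPart events.
fake_stream = [
    {"PayloadPart": {"Bytes": b'data:{"token": {"text": "Hello"}}\n\n'}},
    # the second token's line is deliberately split across two chunks
    {"PayloadPart": {"Bytes": b'data:{"token": {"te'}},
    {"PayloadPart": {"Bytes": b'xt": " world"}}\n\n'}},
]
for token in TokenIterator(fake_stream):
    print(repr(token))  # 'Hello', then ' world'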


@@ -25,6 +25,33 @@ class SagemakerError(Exception):
        )  # Call the base class constructor with the parameters it needs


import io
import json


class TokenIterator:
    """Re-assembles TGI server-sent events from raw PayloadPart byte chunks."""

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                # a full `data:{...}` line is buffered; advance past it plus
                # the blank SSE separator line, then return the decoded token
                self.read_pos += len(line) + 1
                full_line = line[:-1].decode("utf-8")
                line_data = json.loads(full_line.lstrip("data:").rstrip("\n"))
                return line_data["token"]["text"]
            # no complete line buffered yet -- pull in the next PayloadPart
            chunk = next(self.byte_iterator)
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])


class SagemakerConfig:
    """
    Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
@@ -121,7 +148,6 @@ def completion(
    # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
    inference_params = deepcopy(optional_params)
-   inference_params.pop("stream", None)

    ## Load Config
    config = litellm.SagemakerConfig.get_config()
@@ -152,6 +178,28 @@ def completion(
            hf_model_name or model
        )  # pass in hf model name for pulling its prompt template (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
        prompt = prompt_factory(model=hf_model_name, messages=messages)

    stream = inference_params.pop("stream", None)
    if stream == True:
        data = json.dumps(
            {"inputs": prompt, "parameters": inference_params, "stream": True}
        ).encode("utf-8")
        ## LOGGING
        request_str = f"""
        response = client.invoke_endpoint_with_response_stream(
            EndpointName={model},
            ContentType="application/json",
            Body={data},
            CustomAttributes="accept_eula=true",
        )
        """  # type: ignore
        response = client.invoke_endpoint_with_response_stream(
            EndpointName=model,
            ContentType="application/json",
            Body=data,
            CustomAttributes="accept_eula=true",
        )
        return response["Body"]

    data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
        "utf-8"
    )


@@ -1520,10 +1520,12 @@
             # fake streaming for sagemaker
             print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
-            resp_string = model_response["choices"][0]["message"]["content"]
+            from .llms.sagemaker import TokenIterator
+
+            tokenIterator = TokenIterator(model_response)
             response = CustomStreamWrapper(
-                resp_string,
-                model,
+                completion_stream=tokenIterator,
+                model=model,
                 custom_llm_provider="sagemaker",
                 logging_obj=logging,
             )
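With CustomStreamWrapper fed a live TokenIterator, a stream=True call against a SageMaker endpoint now yields OpenAI-style delta chunks as tokens arrive, instead of re-chunking a finished response. A minimal usage sketch (the endpoint name is hypothetical; the accumulation pattern matches the new test below):

import litellm

response = litellm.completion(
    model="sagemaker/my-tgi-endpoint",  # hypothetical endpoint name
    messages=[{"role": "user", "content": "How do I build a website?"}],
    max_tokens=80,
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")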


@@ -1394,6 +1394,30 @@ def test_completion_sagemaker():
# test_completion_sagemaker()


def test_completion_sagemaker_stream():
    try:
        litellm.set_verbose = False
        print("testing sagemaker")
        response = completion(
            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
            messages=messages,
            temperature=0.2,
            max_tokens=80,
            stream=True,
        )
        complete_streaming_response = ""
        for chunk in response:
            print(chunk)
            complete_streaming_response += chunk.choices[0].delta.content or ""
        # Add any assertions here to check the response
        # print(response)
        assert len(complete_streaming_response) > 0
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_completion_chat_sagemaker():
    try:
        messages = [{"role": "user", "content": "Hey, how's it going?"}]


@@ -7736,18 +7736,9 @@ class CustomStreamWrapper:
                     self.sent_last_chunk = True
             elif self.custom_llm_provider == "sagemaker":
                 print_verbose(f"ENTERS SAGEMAKER STREAMING")
-                if len(self.completion_stream) == 0:
-                    if self.sent_last_chunk:
-                        raise StopIteration
-                    else:
-                        model_response.choices[0].finish_reason = "stop"
-                        self.sent_last_chunk = True
-                new_chunk = self.completion_stream
-                print_verbose(f"sagemaker chunk: {new_chunk}")
+                new_chunk = next(self.completion_stream)
                 completion_obj["content"] = new_chunk
-                self.completion_stream = self.completion_stream[
-                    len(self.completion_stream) :
-                ]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.sent_last_chunk:
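A closing note on the wrapper change: the removed branch ended the stream by hand, slicing `self.completion_stream` down to an empty string and then setting `finish_reason = "stop"`, while the new branch defers to the iterator protocol, since `next()` raises StopIteration once the underlying EventStream is exhausted. A hedged sketch of that termination path, reusing TokenIterator and the synthetic `fake_stream` from the first file's sketch:

it = TokenIterator(fake_stream)
while True:
    try:
        token = next(it)  # one decoded token per call
    except StopIteration:
        break  # stream exhausted; iteration ends without manual bookkeeping
    print(token)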