diff --git a/cookbook/misc/sagmaker_streaming.py b/cookbook/misc/sagmaker_streaming.py
new file mode 100644
index 0000000000..81d857b07f
--- /dev/null
+++ b/cookbook/misc/sagmaker_streaming.py
@@ -0,0 +1,61 @@
+# Notes - on how to do sagemaker streaming using boto3
+import json
+import boto3
+
+import sys, os
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os, io
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+
+import io
+import json
+
+
+class TokenIterator:
+    def __init__(self, stream):
+        self.byte_iterator = iter(stream)
+        self.buffer = io.BytesIO()
+        self.read_pos = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        while True:
+            self.buffer.seek(self.read_pos)
+            line = self.buffer.readline()
+            if line and line[-1] == ord("\n"):
+                self.read_pos += len(line) + 1
+                full_line = line[:-1].decode("utf-8")
+                line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
+                return line_data["token"]["text"]
+            chunk = next(self.byte_iterator)
+            self.buffer.seek(0, io.SEEK_END)
+            self.buffer.write(chunk["PayloadPart"]["Bytes"])
+
+
+payload = {
+    "inputs": "How do I build a website?",
+    "parameters": {"max_new_tokens": 256},
+    "stream": True,
+}
+
+import boto3
+
+client = boto3.client("sagemaker-runtime", region_name="us-west-2")
+response = client.invoke_endpoint_with_response_stream(
+    EndpointName="berri-benchmarking-Llama-2-70b-chat-hf-4",
+    Body=json.dumps(payload),
+    ContentType="application/json",
+)
+
+# for token in TokenIterator(response["Body"]):
+#     print(token)
diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index 7b50b05af3..1608f7a0ff 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -25,6 +25,33 @@ class SagemakerError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 
+import io
+import json
+
+
+class TokenIterator:
+    def __init__(self, stream):
+        self.byte_iterator = iter(stream)
+        self.buffer = io.BytesIO()
+        self.read_pos = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        while True:
+            self.buffer.seek(self.read_pos)
+            line = self.buffer.readline()
+            if line and line[-1] == ord("\n"):
+                self.read_pos += len(line) + 1
+                full_line = line[:-1].decode("utf-8")
+                line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
+                return line_data["token"]["text"]
+            chunk = next(self.byte_iterator)
+            self.buffer.seek(0, io.SEEK_END)
+            self.buffer.write(chunk["PayloadPart"]["Bytes"])
+
+
 class SagemakerConfig:
     """
     Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
@@ -121,7 +148,6 @@ def completion(
 
     # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
     inference_params = deepcopy(optional_params)
-    inference_params.pop("stream", None)
 
     ## Load Config
     config = litellm.SagemakerConfig.get_config()
@@ -152,6 +178,28 @@ def completion(
             hf_model_name or model
         )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
         prompt = prompt_factory(model=hf_model_name, messages=messages)
+    stream = inference_params.pop("stream", None)
+    if stream == True:
+        data = json.dumps(
+            {"inputs": prompt, "parameters": inference_params, "stream": True}
+        ).encode("utf-8")
+        ## LOGGING
+        request_str = f"""
+        response = client.invoke_endpoint_with_response_stream(
+            EndpointName={model},
+            ContentType="application/json",
+            Body={data},
+            CustomAttributes="accept_eula=true",
+        )
+        """  # type: ignore
+        response = client.invoke_endpoint_with_response_stream(
+            EndpointName=model,
+            ContentType="application/json",
+            Body=data,
+            CustomAttributes="accept_eula=true",
+        )
+
+        return response["Body"]
 
     data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
         "utf-8"
diff --git a/litellm/main.py b/litellm/main.py
index 9c09085b13..6b9a0bb185 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1520,10 +1520,12 @@ def completion(
 
             # fake streaming for sagemaker
             print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
-            resp_string = model_response["choices"][0]["message"]["content"]
+            from .llms.sagemaker import TokenIterator
+
+            tokenIterator = TokenIterator(model_response)
             response = CustomStreamWrapper(
-                resp_string,
-                model,
+                completion_stream=tokenIterator,
+                model=model,
                 custom_llm_provider="sagemaker",
                 logging_obj=logging,
             )
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 644b348ec0..43ffd2b0a3 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1394,6 +1394,30 @@ def test_completion_sagemaker():
 # test_completion_sagemaker()
 
 
+def test_completion_sagemaker_stream():
+    try:
+        litellm.set_verbose = False
+        print("testing sagemaker")
+        response = completion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=messages,
+            temperature=0.2,
+            max_tokens=80,
+            stream=True,
+        )
+
+        complete_streaming_response = ""
+
+        for chunk in response:
+            print(chunk)
+            complete_streaming_response += chunk.choices[0].delta.content or ""
+        # Add any assertions here to check the response
+        # print(response)
+        assert len(complete_streaming_response) > 0
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_chat_sagemaker():
     try:
         messages = [{"role": "user", "content": "Hey, how's it going?"}]
diff --git a/litellm/utils.py b/litellm/utils.py
index c8c363d6c6..e47e19b52e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -7736,18 +7736,9 @@ class CustomStreamWrapper:
                         self.sent_last_chunk = True
             elif self.custom_llm_provider == "sagemaker":
                 print_verbose(f"ENTERS SAGEMAKER STREAMING")
-                if len(self.completion_stream) == 0:
-                    if self.sent_last_chunk:
-                        raise StopIteration
-                    else:
-                        model_response.choices[0].finish_reason = "stop"
-                        self.sent_last_chunk = True
-                new_chunk = self.completion_stream
-                print_verbose(f"sagemaker chunk: {new_chunk}")
+                new_chunk = next(self.completion_stream)
+
                 completion_obj["content"] = new_chunk
-                self.completion_stream = self.completion_stream[
-                    len(self.completion_stream) :
-                ]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.sent_last_chunk:
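
A minimal sketch of how the TokenIterator added above turns a SageMaker response stream into text tokens, assuming this branch of litellm is importable. The fake_events list is a hypothetical stand-in for the PayloadPart events returned by invoke_endpoint_with_response_stream, shaped as the TGI-style "data:{...}" lines that the iterator's parsing expects; it is not part of the diff.

# Sketch only (assumptions noted above): feed TokenIterator a fake event list
# shaped like a SageMaker streaming response and print the decoded tokens.
from litellm.llms.sagemaker import TokenIterator

fake_events = [
    {"PayloadPart": {"Bytes": b'data:{"token": {"text": "Hello"}}\n\n'}},
    {"PayloadPart": {"Bytes": b'data:{"token": {"text": " world"}}\n\n'}},
]

for token in TokenIterator(fake_events):
    print(token)  # prints "Hello", then " world"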