From 998094d38dca6d33db78cce6362ebd42894b2ace Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Mon, 22 Jan 2024 21:49:44 -0800
Subject: [PATCH] v0 add TokenIterator, stream support

---
 litellm/llms/sagemaker.py | 50 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 49 insertions(+), 1 deletion(-)

diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index 7b50b05af3..1608f7a0ff 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -25,6 +25,33 @@ class SagemakerError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 
+import io
+import json
+
+
+class TokenIterator:
+    def __init__(self, stream):
+        self.byte_iterator = iter(stream)
+        self.buffer = io.BytesIO()
+        self.read_pos = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        while True:
+            self.buffer.seek(self.read_pos)
+            line = self.buffer.readline()
+            if line and line[-1] == ord("\n"):
+                self.read_pos += len(line) + 1
+                full_line = line[:-1].decode("utf-8")
+                line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
+                return line_data["token"]["text"]
+            chunk = next(self.byte_iterator)
+            self.buffer.seek(0, io.SEEK_END)
+            self.buffer.write(chunk["PayloadPart"]["Bytes"])
+
+
 class SagemakerConfig:
     """
     Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
@@ -121,7 +148,6 @@ def completion(
 
     # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
     inference_params = deepcopy(optional_params)
-    inference_params.pop("stream", None)
 
     ## Load Config
     config = litellm.SagemakerConfig.get_config()
@@ -152,6 +178,28 @@ def completion(
         hf_model_name or model
     )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
     prompt = prompt_factory(model=hf_model_name, messages=messages)
+    stream = inference_params.pop("stream", None)
+    if stream == True:
+        data = json.dumps(
+            {"inputs": prompt, "parameters": inference_params, "stream": True}
+        ).encode("utf-8")
+        ## LOGGING
+        request_str = f"""
+        response = client.invoke_endpoint_with_response_stream(
+            EndpointName={model},
+            ContentType="application/json",
+            Body={data},
+            CustomAttributes="accept_eula=true",
+        )
+        """  # type: ignore
+        response = client.invoke_endpoint_with_response_stream(
+            EndpointName=model,
+            ContentType="application/json",
+            Body=data,
+            CustomAttributes="accept_eula=true",
+        )
+
+        return response["Body"]
 
     data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
         "utf-8"
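
For context, a minimal, hypothetical sketch of how the new streaming path could be exercised: the EventStream returned by invoke_endpoint_with_response_stream (which the streaming branch above returns as response["Body"]) is wrapped in the new TokenIterator and iterated token by token. The region, endpoint name, prompt, and generation parameters below are placeholders, not part of this patch.

# Hypothetical usage sketch (not part of the patch): stream tokens from a
# SageMaker text-generation endpoint using the TokenIterator added above.
import json

import boto3

from litellm.llms.sagemaker import TokenIterator

client = boto3.client("sagemaker-runtime", region_name="us-west-2")  # placeholder region

payload = json.dumps(
    {
        "inputs": "What is the capital of France?",  # placeholder prompt
        "parameters": {"max_new_tokens": 64},  # placeholder parameters
        "stream": True,
    }
).encode("utf-8")

response = client.invoke_endpoint_with_response_stream(
    EndpointName="meta-textgeneration-llama-2-7b",  # placeholder endpoint name
    ContentType="application/json",
    Body=payload,
    CustomAttributes="accept_eula=true",
)

# TokenIterator buffers the PayloadPart bytes from the EventStream and yields the
# decoded token text of each newline-terminated "data:{...}" line it accumulates.
for token in TokenIterator(response["Body"]):
    print(token, end="", flush=True)

Since the streaming branch in completion() returns the same response["Body"] object, the final loop would apply unchanged to the value returned there.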