mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
Merge pull request #1561 from BerriAI/litellm_sagemaker_streaming
[Feat] Add REAL Sagemaker streaming
This commit is contained in:
commit
97dd61a6cb
5 changed files with 141 additions and 15 deletions
61
cookbook/misc/sagmaker_streaming.py
Normal file
61
cookbook/misc/sagmaker_streaming.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
# Notes - on how to do sagemaker streaming using boto3
|
||||
import json
|
||||
import boto3
|
||||
|
||||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os, io
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
|
||||
import io
|
||||
import json
|
||||
|
||||
|
||||
class TokenIterator:
|
||||
def __init__(self, stream):
|
||||
self.byte_iterator = iter(stream)
|
||||
self.buffer = io.BytesIO()
|
||||
self.read_pos = 0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
while True:
|
||||
self.buffer.seek(self.read_pos)
|
||||
line = self.buffer.readline()
|
||||
if line and line[-1] == ord("\n"):
|
||||
self.read_pos += len(line) + 1
|
||||
full_line = line[:-1].decode("utf-8")
|
||||
line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
|
||||
return line_data["token"]["text"]
|
||||
chunk = next(self.byte_iterator)
|
||||
self.buffer.seek(0, io.SEEK_END)
|
||||
self.buffer.write(chunk["PayloadPart"]["Bytes"])
|
||||
|
||||
|
||||
payload = {
|
||||
"inputs": "How do I build a website?",
|
||||
"parameters": {"max_new_tokens": 256},
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
import boto3
|
||||
|
||||
client = boto3.client("sagemaker-runtime", region_name="us-west-2")
|
||||
response = client.invoke_endpoint_with_response_stream(
|
||||
EndpointName="berri-benchmarking-Llama-2-70b-chat-hf-4",
|
||||
Body=json.dumps(payload),
|
||||
ContentType="application/json",
|
||||
)
|
||||
|
||||
# for token in TokenIterator(response["Body"]):
|
||||
# print(token)
|
|
@ -25,6 +25,33 @@ class SagemakerError(Exception):
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
import io
|
||||
import json
|
||||
|
||||
|
||||
class TokenIterator:
|
||||
def __init__(self, stream):
|
||||
self.byte_iterator = iter(stream)
|
||||
self.buffer = io.BytesIO()
|
||||
self.read_pos = 0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
while True:
|
||||
self.buffer.seek(self.read_pos)
|
||||
line = self.buffer.readline()
|
||||
if line and line[-1] == ord("\n"):
|
||||
self.read_pos += len(line) + 1
|
||||
full_line = line[:-1].decode("utf-8")
|
||||
line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
|
||||
return line_data["token"]["text"]
|
||||
chunk = next(self.byte_iterator)
|
||||
self.buffer.seek(0, io.SEEK_END)
|
||||
self.buffer.write(chunk["PayloadPart"]["Bytes"])
|
||||
|
||||
|
||||
class SagemakerConfig:
|
||||
"""
|
||||
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
|
||||
|
@ -121,7 +148,6 @@ def completion(
|
|||
|
||||
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
|
||||
inference_params = deepcopy(optional_params)
|
||||
inference_params.pop("stream", None)
|
||||
|
||||
## Load Config
|
||||
config = litellm.SagemakerConfig.get_config()
|
||||
|
@ -152,6 +178,28 @@ def completion(
|
|||
hf_model_name or model
|
||||
) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
|
||||
prompt = prompt_factory(model=hf_model_name, messages=messages)
|
||||
stream = inference_params.pop("stream", None)
|
||||
if stream == True:
|
||||
data = json.dumps(
|
||||
{"inputs": prompt, "parameters": inference_params, "stream": True}
|
||||
).encode("utf-8")
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
response = client.invoke_endpoint_with_response_stream(
|
||||
EndpointName={model},
|
||||
ContentType="application/json",
|
||||
Body={data},
|
||||
CustomAttributes="accept_eula=true",
|
||||
)
|
||||
""" # type: ignore
|
||||
response = client.invoke_endpoint_with_response_stream(
|
||||
EndpointName=model,
|
||||
ContentType="application/json",
|
||||
Body=data,
|
||||
CustomAttributes="accept_eula=true",
|
||||
)
|
||||
|
||||
return response["Body"]
|
||||
|
||||
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
|
||||
"utf-8"
|
||||
|
|
|
@ -1520,10 +1520,12 @@ def completion(
|
|||
|
||||
# fake streaming for sagemaker
|
||||
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
|
||||
resp_string = model_response["choices"][0]["message"]["content"]
|
||||
from .llms.sagemaker import TokenIterator
|
||||
|
||||
tokenIterator = TokenIterator(model_response)
|
||||
response = CustomStreamWrapper(
|
||||
resp_string,
|
||||
model,
|
||||
completion_stream=tokenIterator,
|
||||
model=model,
|
||||
custom_llm_provider="sagemaker",
|
||||
logging_obj=logging,
|
||||
)
|
||||
|
|
|
@ -1394,6 +1394,30 @@ def test_completion_sagemaker():
|
|||
# test_completion_sagemaker()
|
||||
|
||||
|
||||
def test_completion_sagemaker_stream():
|
||||
try:
|
||||
litellm.set_verbose = False
|
||||
print("testing sagemaker")
|
||||
response = completion(
|
||||
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
max_tokens=80,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
complete_streaming_response = ""
|
||||
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
complete_streaming_response += chunk.choices[0].delta.content or ""
|
||||
# Add any assertions here to check the response
|
||||
# print(response)
|
||||
assert len(complete_streaming_response) > 0
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_chat_sagemaker():
|
||||
try:
|
||||
messages = [{"role": "user", "content": "Hey, how's it going?"}]
|
||||
|
|
|
@ -7736,18 +7736,9 @@ class CustomStreamWrapper:
|
|||
self.sent_last_chunk = True
|
||||
elif self.custom_llm_provider == "sagemaker":
|
||||
print_verbose(f"ENTERS SAGEMAKER STREAMING")
|
||||
if len(self.completion_stream) == 0:
|
||||
if self.sent_last_chunk:
|
||||
raise StopIteration
|
||||
else:
|
||||
model_response.choices[0].finish_reason = "stop"
|
||||
self.sent_last_chunk = True
|
||||
new_chunk = self.completion_stream
|
||||
print_verbose(f"sagemaker chunk: {new_chunk}")
|
||||
new_chunk = next(self.completion_stream)
|
||||
|
||||
completion_obj["content"] = new_chunk
|
||||
self.completion_stream = self.completion_stream[
|
||||
len(self.completion_stream) :
|
||||
]
|
||||
elif self.custom_llm_provider == "petals":
|
||||
if len(self.completion_stream) == 0:
|
||||
if self.sent_last_chunk:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue