sagemaker streaming

ishaan-jaff 2023-09-21 16:14:44 -07:00
parent a8c0f46111
commit 6add152818
3 changed files with 36 additions and 4 deletions


@@ -866,10 +866,15 @@ def completion(
             logging_obj=logging
         )
-        if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
-            # don't try to access stream object,
+        if stream==True: ## [BETA]
+            # sagemaker does not support streaming as of now so we're faking streaming:
+            # https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611
+            # "SageMaker is currently not supporting streaming responses."
+            # fake streaming for sagemaker
+            resp_string = model_response["choices"][0]["message"]["content"]
             response = CustomStreamWrapper(
-                iter(model_response), model, custom_llm_provider="sagemaker", logging_obj=logging
+                resp_string, model, custom_llm_provider="sagemaker", logging_obj=logging
             )
             return response
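
In other words, the change stops wrapping `iter(model_response)` and instead hands the wrapper the already-complete response text, which is then replayed in small slices (see the CustomStreamWrapper hunk further down). A minimal standalone sketch of that fake-streaming idea, using hypothetical names unrelated to litellm's actual classes:

    import time
    from typing import Iterator

    def fake_stream(full_text: str, chunk_size: int = 30, delay: float = 0.05) -> Iterator[str]:
        # The response is already fully generated; slice it into fixed-size
        # pieces and yield them with a short pause to imitate token streaming.
        for start in range(0, len(full_text), chunk_size):
            yield full_text[start:start + chunk_size]
            time.sleep(delay)

    for piece in fake_stream("SageMaker returned this whole string in one shot."):
        print(piece, end="", flush=True)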


@@ -621,7 +621,26 @@ def test_completion_sagemaker():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_sagemaker()
+# test_completion_sagemaker()
+
+def test_completion_sagemaker_stream():
+    litellm.set_verbose = False
+    try:
+        response = completion(
+            model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
+            messages=messages,
+            temperature=0.2,
+            max_tokens=80,
+            stream=True,
+        )
+        # Add any assertions here to check the response
+        for chunk in response:
+            print(chunk)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+# test_completion_sagemaker_stream()
 def test_completion_bedrock_titan():
     try:
         response = completion(
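
The new test only prints each chunk. To reassemble the streamed text, something like the following should work; this is a sketch that assumes each chunk is dict-like and follows the OpenAI streaming layout (`choices[0]["delta"]["content"]`), which is the shape litellm normalizes stream chunks to — verify against your litellm version:

    from litellm import completion

    # Sketch: reassemble a fake-streamed sagemaker response.
    # Assumes chunks are dict-like with the OpenAI streaming layout.
    response = completion(
        model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
        messages=[{"role": "user", "content": "hello"}],
        max_tokens=80,
        stream=True,
    )
    collected = []
    for chunk in response:
        delta = chunk["choices"][0]["delta"]
        # "content" may be absent on a final chunk, so fall back to ""
        collected.append(delta.get("content") or "")
    print("".join(collected))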


@@ -2720,6 +2720,14 @@ class CustomStreamWrapper:
                 completion_obj["content"] = self.handle_cohere_chunk(chunk)
             elif self.custom_llm_provider == "bedrock":
                 completion_obj["content"] = self.handle_bedrock_stream()
+            elif self.custom_llm_provider == "sagemaker":
+                if len(self.completion_stream)==0:
+                    raise StopIteration
+                chunk_size = 30
+                new_chunk = self.completion_stream[:chunk_size]
+                completion_obj["content"] = new_chunk
+                self.completion_stream = self.completion_stream[chunk_size:]
+                time.sleep(0.05)
             else: # openai chat/azure models
                 chunk = next(self.completion_stream)
                 model_response = chunk
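
Two notes on this branch: raising StopIteration inside __next__ is how the wrapper signals end-of-stream once the buffered string is consumed, and the time.sleep(0.05) only simulates inter-chunk latency — time-to-first-chunk is still the full generation time, since the response was produced up front. The final partial slice needs no special-casing either, because Python slicing past the end of a string simply returns the shorter remainder. A quick standalone check of that slicing arithmetic, with hypothetical values:

    stream = "x" * 65  # 65 chars -> expect slices of 30, 30, 5
    chunk_size = 30
    sizes = []
    while len(stream) > 0:
        sizes.append(len(stream[:chunk_size]))
        stream = stream[chunk_size:]
    assert sizes == [30, 30, 5]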