mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
sagemaker streaming
This commit is contained in:
parent a8c0f46111
commit 6add152818
3 changed files with 36 additions and 4 deletions
@@ -866,10 +866,15 @@ def completion(
             logging_obj=logging
         )

-        if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
-            # don't try to access stream object,
+        if stream==True: ## [BETA]
+            # sagemaker does not support streaming as of now so we're faking streaming:
+            # https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611
+            # "SageMaker is currently not supporting streaming responses."
+
+            # fake streaming for sagemaker
+            resp_string = model_response["choices"][0]["message"]["content"]
             response = CustomStreamWrapper(
-                iter(model_response), model, custom_llm_provider="sagemaker", logging_obj=logging
+                resp_string, model, custom_llm_provider="sagemaker", logging_obj=logging
             )
             return response
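The key change in this hunk is the first argument to CustomStreamWrapper: `iter(model_response)` iterates the response object itself (for a dict-like object, that yields keys, not text), while `resp_string` is the already-complete response text that the wrapper can slice into pseudo-chunks. A minimal illustration, using a plain dict as a hypothetical stand-in for the real litellm ModelResponse:

# Plain-dict stand-in for a completed response (hypothetical, not the
# actual litellm ModelResponse class).
model_response = {"choices": [{"message": {"content": "full completion text"}}]}

# Old call site: iterating a dict-like response yields its keys, not text.
print(list(iter(model_response)))  # ['choices']

# New call site: extract the finished text so the wrapper can slice it.
resp_string = model_response["choices"][0]["message"]["content"]
print(resp_string)  # 'full completion text'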
@@ -621,7 +621,26 @@ def test_completion_sagemaker():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_completion_sagemaker()
+# test_completion_sagemaker()
+
+def test_completion_sagemaker_stream():
+    litellm.set_verbose = False
+    try:
+        response = completion(
+            model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
+            messages=messages,
+            temperature=0.2,
+            max_tokens=80,
+            stream=True,
+        )
+        # Add any assertions here to check the response
+        for chunk in response:
+            print(chunk)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+# test_completion_sagemaker_stream()
+
 def test_completion_bedrock_titan():
     try:
         response = completion(
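The new test only prints each chunk. A caller who wants the full text back would concatenate the deltas. A hedged sketch, assuming litellm's usual streaming chunk shape with the incremental text at chunk["choices"][0]["delta"]["content"]; the diff itself only sets completion_obj["content"], so that exact path and the message below are assumptions:

from litellm import completion

response = completion(
    model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],  # hypothetical prompt
    stream=True,
)

full_text = ""
for chunk in response:
    # Assumption: incremental text lives at choices[0].delta.content
    # (litellm's usual streaming layout); it may be None on the final chunk.
    piece = chunk["choices"][0]["delta"]["content"] or ""
    full_text += piece
print(full_text)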
@@ -2720,6 +2720,14 @@ class CustomStreamWrapper:
                 completion_obj["content"] = self.handle_cohere_chunk(chunk)
             elif self.custom_llm_provider == "bedrock":
                 completion_obj["content"] = self.handle_bedrock_stream()
+            elif self.custom_llm_provider == "sagemaker":
+                if len(self.completion_stream)==0:
+                    raise StopIteration
+                chunk_size = 30
+                new_chunk = self.completion_stream[:chunk_size]
+                completion_obj["content"] = new_chunk
+                self.completion_stream = self.completion_stream[chunk_size:]
+                time.sleep(0.05)
             else: # openai chat/azure models
                 chunk = next(self.completion_stream)
                 model_response = chunk
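In the sagemaker branch, self.completion_stream is the raw response string rather than an iterator, so each pass slices off the first 30 characters and sleeps briefly to mimic network pacing, raising StopIteration once the string is drained. The same logic, extracted into a self-contained iterator as a sketch (class name and defaults are hypothetical, not litellm API):

import time

class FakeStringStream:
    """Standalone sketch of the slicing logic the diff adds: emit a stored
    string in fixed-size pieces, raising StopIteration when drained."""

    def __init__(self, text, chunk_size=30, delay=0.05):
        self.remaining = text
        self.chunk_size = chunk_size
        self.delay = delay

    def __iter__(self):
        return self

    def __next__(self):
        if len(self.remaining) == 0:
            raise StopIteration
        chunk = self.remaining[:self.chunk_size]
        self.remaining = self.remaining[self.chunk_size:]
        time.sleep(self.delay)  # mirrors the diff's time.sleep(0.05)
        return chunk

print(list(FakeStringStream("a long sagemaker completion", chunk_size=10)))

Note that with fixed 30-character slices, chunk boundaries fall mid-word and mid-token, so consumers should concatenate chunks rather than treat each one as a token.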