diff --git a/litellm/main.py b/litellm/main.py
index 150fa31f1d..4428bef1b7 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -866,10 +866,15 @@ def completion(
             logging_obj=logging
         )
-        if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
-            # don't try to access stream object,
+        if stream==True: ## [BETA]
+            # sagemaker does not support streaming as of now so we're faking streaming:
+            # https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611
+            # "SageMaker is currently not supporting streaming responses."
+
+            # fake streaming for sagemaker
+            resp_string = model_response["choices"][0]["message"]["content"]
             response = CustomStreamWrapper(
-                iter(model_response), model, custom_llm_provider="sagemaker", logging_obj=logging
+                resp_string, model, custom_llm_provider="sagemaker", logging_obj=logging
             )
             return response
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 6bc56ac8f6..fecec84442 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -621,7 +621,26 @@ def test_completion_sagemaker():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_sagemaker()
+# test_completion_sagemaker()
+
+def test_completion_sagemaker_stream():
+    litellm.set_verbose = False
+    try:
+        response = completion(
+            model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
+            messages=messages,
+            temperature=0.2,
+            max_tokens=80,
+            stream=True,
+        )
+        # Add any assertions here to check the response
+        for chunk in response:
+            print(chunk)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+# test_completion_sagemaker_stream()
+
 def test_completion_bedrock_titan():
     try:
         response = completion(
diff --git a/litellm/utils.py b/litellm/utils.py
index 2590df7936..325fea5e34 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2720,6 +2720,14 @@ class CustomStreamWrapper:
                 completion_obj["content"] = self.handle_cohere_chunk(chunk)
             elif self.custom_llm_provider == "bedrock":
                 completion_obj["content"] = self.handle_bedrock_stream()
+            elif self.custom_llm_provider == "sagemaker":
+                if len(self.completion_stream)==0:
+                    raise StopIteration
+                chunk_size = 30
+                new_chunk = self.completion_stream[:chunk_size]
+                completion_obj["content"] = new_chunk
+                self.completion_stream = self.completion_stream[chunk_size:]
+                time.sleep(0.05)
             else: # openai chat/azure models
                 chunk = next(self.completion_stream)
                 model_response = chunk
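For reference, here is a minimal standalone sketch of the fake-streaming technique the utils.py hunk implements: the already-complete SageMaker response string is sliced into fixed-size chunks and emitted with a short delay, so callers can consume it like a real streaming iterator. The `fake_stream` helper name and its parameters are illustrative only, not part of litellm's `CustomStreamWrapper`.

```python
# Illustrative sketch of fake streaming (not litellm code): replay a finished
# response string as fixed-size chunks with a small pause between them.
import time
from typing import Iterator


def fake_stream(resp_string: str, chunk_size: int = 30, delay: float = 0.05) -> Iterator[str]:
    """Yield resp_string in chunk_size-character pieces, sleeping delay seconds between them."""
    while resp_string:
        chunk, resp_string = resp_string[:chunk_size], resp_string[chunk_size:]
        time.sleep(delay)  # mimic the gap between real streamed chunks
        yield chunk


if __name__ == "__main__":
    # Consume the fake stream the same way the new test iterates over `response`.
    for piece in fake_stream("This response was generated in one shot but is replayed as chunks."):
        print(piece)
```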