From 00d18cbc86f1b8ffc6f4fff4f999c4bc6b51c4fb Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 23 Jan 2024 12:08:58 -0800
Subject: [PATCH] (fix) select_data_generator - sagemaker

---
 litellm/proxy/proxy_server.py | 37 +++++++++++++++++++++----------------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index b874d7d80d..90929abacb 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1436,6 +1436,19 @@ async def async_data_generator(response, user_api_key_dict):
         yield f"data: {str(e)}\n\n"
 
 
+def select_data_generator(response, user_api_key_dict):
+    # since boto3 - sagemaker does not support async calls
+    if response.custom_llm_provider == "sagemaker":
+        return data_generator(
+            response=response,
+        )
+    else:
+        # default to async_data_generator
+        return async_data_generator(
+            response=response, user_api_key_dict=user_api_key_dict
+        )
+
+
 def get_litellm_model_info(model: dict = {}):
     model_info = model.get("model_info", {})
     model_to_lookup = model.get("litellm_params", {}).get("model", None)
@@ -1672,16 +1685,12 @@ async def completion(
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
             custom_headers = {"x-litellm-model-id": model_id}
-            stream_content = async_data_generator(
-                user_api_key_dict=user_api_key_dict,
-                response=response,
+            selected_data_generator = select_data_generator(
+                response=response, user_api_key_dict=user_api_key_dict
             )
-            if response.custom_llm_provider == "sagemaker":
-                stream_content = data_generator(
-                    response=response,
-                )
+
             return StreamingResponse(
-                stream_content,
+                selected_data_generator,
                 media_type="text/event-stream",
                 headers=custom_headers,
             )
@@ -1839,16 +1848,12 @@ async def chat_completion(
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
             custom_headers = {"x-litellm-model-id": model_id}
-            stream_content = async_data_generator(
-                user_api_key_dict=user_api_key_dict,
-                response=response,
+            selected_data_generator = select_data_generator(
+                response=response, user_api_key_dict=user_api_key_dict
             )
-            if response.custom_llm_provider == "sagemaker":
-                stream_content = data_generator(
-                    response=response,
-                )
+
             return StreamingResponse(
-                stream_content,
+                selected_data_generator,
                 media_type="text/event-stream",
                 headers=custom_headers,
             )
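
Reviewer note (not part of the patch): the refactor works because Starlette's StreamingResponse, which FastAPI re-exports, accepts both plain and async generators, so select_data_generator can hand back the synchronous data_generator for SageMaker (boto3 exposes no async streaming call here) and async_data_generator for every other provider. Below is a minimal, self-contained sketch of that pattern; every name except FastAPI/StreamingResponse is a hypothetical stand-in, not the proxy's actual code.

```python
# Illustrative sketch only: provider-aware choice between a sync and an async
# generator, both consumed by the same StreamingResponse. Names other than
# FastAPI/StreamingResponse are hypothetical stand-ins.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def sync_chunks():
    # stands in for data_generator(): a plain generator, usable when the
    # underlying SDK (e.g. boto3) only supports blocking iteration
    for i in range(3):
        yield f"data: chunk-{i}\n\n"


async def async_chunks():
    # stands in for async_data_generator(): an async generator for providers
    # whose SDKs support awaitable streaming
    for i in range(3):
        yield f"data: chunk-{i}\n\n"


def pick_generator(provider: str):
    # mirrors the shape of select_data_generator(): sync path for sagemaker,
    # async path for everything else
    if provider == "sagemaker":
        return sync_chunks()
    return async_chunks()


@app.get("/stream/{provider}")
async def stream(provider: str):
    # StreamingResponse iterates either kind of generator; sync generators are
    # run in a threadpool by Starlette, so the event loop is not blocked
    return StreamingResponse(
        pick_generator(provider), media_type="text/event-stream"
    )
```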