(fix) select_data_generator - sagemaker

ishaan-jaff 2024-01-23 12:08:58 -08:00
parent 67dddc94d9
commit 00d18cbc86


@@ -1436,6 +1436,19 @@ async def async_data_generator(response, user_api_key_dict):
         yield f"data: {str(e)}\n\n"
 
 
+def select_data_generator(response, user_api_key_dict):
+    # since boto3 - sagemaker does not support async calls
+    if response.custom_llm_provider == "sagemaker":
+        return data_generator(
+            response=response,
+        )
+    else:
+        # default to async_data_generator
+        return async_data_generator(
+            response=response, user_api_key_dict=user_api_key_dict
+        )
+
+
 def get_litellm_model_info(model: dict = {}):
     model_info = model.get("model_info", {})
     model_to_lookup = model.get("litellm_params", {}).get("model", None)
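For context: the helper above only chooses between two generator shapes that already exist in the proxy. Below is a minimal, illustrative sketch of those shapes; the chunk serialization and error handling are simplified assumptions, not the repository's actual implementations.

import json

def data_generator(response):
    # synchronous SSE generator: a boto3/SageMaker stream can only be
    # consumed with a plain for-loop, so this path stays sync
    try:
        for chunk in response:
            yield f"data: {json.dumps(chunk)}\n\n"
    except Exception as e:
        yield f"data: {str(e)}\n\n"

async def async_data_generator(response, user_api_key_dict):
    # asynchronous SSE generator: other providers expose async iterators
    try:
        async for chunk in response:
            yield f"data: {json.dumps(chunk)}\n\n"
    except Exception as e:
        yield f"data: {str(e)}\n\n"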
@@ -1672,16 +1685,12 @@ async def completion(
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
             custom_headers = {"x-litellm-model-id": model_id}
-            stream_content = async_data_generator(
-                user_api_key_dict=user_api_key_dict,
-                response=response,
+            selected_data_generator = select_data_generator(
+                response=response, user_api_key_dict=user_api_key_dict
             )
-            if response.custom_llm_provider == "sagemaker":
-                stream_content = data_generator(
-                    response=response,
-                )
             return StreamingResponse(
-                stream_content,
+                selected_data_generator,
                 media_type="text/event-stream",
                 headers=custom_headers,
             )
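The reason one code path can hand either generator to StreamingResponse: Starlette (which FastAPI uses under the hood) accepts both sync and async iterators, running sync ones in a threadpool. A small self-contained demonstration, independent of litellm:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

def sync_chunks():
    # plain generator, e.g. what the sagemaker path returns
    for i in range(3):
        yield f"data: chunk-{i}\n\n"

async def async_chunks():
    # async generator, e.g. the default path
    for i in range(3):
        yield f"data: chunk-{i}\n\n"

@app.get("/sync")
async def stream_sync():
    return StreamingResponse(sync_chunks(), media_type="text/event-stream")

@app.get("/async")
async def stream_async():
    return StreamingResponse(async_chunks(), media_type="text/event-stream")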
@@ -1839,16 +1848,12 @@ async def chat_completion(
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
             custom_headers = {"x-litellm-model-id": model_id}
-            stream_content = async_data_generator(
-                user_api_key_dict=user_api_key_dict,
-                response=response,
+            selected_data_generator = select_data_generator(
+                response=response, user_api_key_dict=user_api_key_dict
             )
-            if response.custom_llm_provider == "sagemaker":
-                stream_content = data_generator(
-                    response=response,
-                )
             return StreamingResponse(
-                stream_content,
+                selected_data_generator,
                 media_type="text/event-stream",
                 headers=custom_headers,
             )
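Why the sagemaker branch needs the sync generator at all: boto3's SageMaker runtime streaming call returns a plain synchronous event stream, which cannot be consumed with async for. A hedged sketch of that consumption, with the endpoint name and payload shape as assumptions:

import json
import boto3

client = boto3.client("sagemaker-runtime")

def stream_sagemaker(endpoint_name: str, prompt: str):
    response = client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps({"inputs": prompt}),
    )
    # response["Body"] is a botocore EventStream: a synchronous iterable
    for event in response["Body"]:
        part = event.get("PayloadPart", {})
        if "Bytes" in part:
            yield part["Bytes"].decode("utf-8")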