mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(fix) select_data_generator - sagemaker
This commit is contained in:
parent
67dddc94d9
commit
00d18cbc86
1 changed file with 21 additions and 16 deletions
|
@ -1436,6 +1436,19 @@ async def async_data_generator(response, user_api_key_dict):
|
||||||
yield f"data: {str(e)}\n\n"
|
yield f"data: {str(e)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def select_data_generator(response, user_api_key_dict):
    """Pick the streaming generator suited to the response's LLM provider.

    boto3's SageMaker client is synchronous only, so SageMaker responses
    must be streamed through the sync ``data_generator``; all other
    providers stream through ``async_data_generator``.
    """
    if response.custom_llm_provider == "sagemaker":
        # since boto3 - sagemaker does not support async calls
        return data_generator(response=response)
    # default to async_data_generator
    return async_data_generator(
        response=response, user_api_key_dict=user_api_key_dict
    )
|
||||||
|
|
||||||
|
|
||||||
def get_litellm_model_info(model: dict = {}):
|
def get_litellm_model_info(model: dict = {}):
|
||||||
model_info = model.get("model_info", {})
|
model_info = model.get("model_info", {})
|
||||||
model_to_lookup = model.get("litellm_params", {}).get("model", None)
|
model_to_lookup = model.get("litellm_params", {}).get("model", None)
|
||||||
|
@ -1672,16 +1685,12 @@ async def completion(
|
||||||
"stream" in data and data["stream"] == True
|
"stream" in data and data["stream"] == True
|
||||||
): # use generate_responses to stream responses
|
): # use generate_responses to stream responses
|
||||||
custom_headers = {"x-litellm-model-id": model_id}
|
custom_headers = {"x-litellm-model-id": model_id}
|
||||||
stream_content = async_data_generator(
|
selected_data_generator = select_data_generator(
|
||||||
user_api_key_dict=user_api_key_dict,
|
response=response, user_api_key_dict=user_api_key_dict
|
||||||
response=response,
|
|
||||||
)
|
|
||||||
if response.custom_llm_provider == "sagemaker":
|
|
||||||
stream_content = data_generator(
|
|
||||||
response=response,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_content,
|
selected_data_generator,
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers=custom_headers,
|
headers=custom_headers,
|
||||||
)
|
)
|
||||||
|
@ -1839,16 +1848,12 @@ async def chat_completion(
|
||||||
"stream" in data and data["stream"] == True
|
"stream" in data and data["stream"] == True
|
||||||
): # use generate_responses to stream responses
|
): # use generate_responses to stream responses
|
||||||
custom_headers = {"x-litellm-model-id": model_id}
|
custom_headers = {"x-litellm-model-id": model_id}
|
||||||
stream_content = async_data_generator(
|
selected_data_generator = select_data_generator(
|
||||||
user_api_key_dict=user_api_key_dict,
|
response=response, user_api_key_dict=user_api_key_dict
|
||||||
response=response,
|
|
||||||
)
|
|
||||||
if response.custom_llm_provider == "sagemaker":
|
|
||||||
stream_content = data_generator(
|
|
||||||
response=response,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_content,
|
selected_data_generator,
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers=custom_headers,
|
headers=custom_headers,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue