(fix) together_ai use sync generator

This commit is contained in:
ishaan-jaff 2024-01-23 20:07:26 -08:00
parent fcd66eac7d
commit 2d26875eb0
2 changed files with 10 additions and 8 deletions

View file

@@ -25,6 +25,9 @@ model_list:
- model_name: BEDROCK_GROUP - model_name: BEDROCK_GROUP
litellm_params: litellm_params:
model: bedrock/cohere.command-text-v14 model: bedrock/cohere.command-text-v14
- model_name: tg-ai
litellm_params:
model: together_ai/mistralai/Mistral-7B-Instruct-v0.1
- model_name: sagemaker - model_name: sagemaker
litellm_params: litellm_params:
model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4 model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4

View file

@@ -1450,10 +1450,9 @@ async def async_data_generator(response, user_api_key_dict):
def select_data_generator(response, user_api_key_dict): def select_data_generator(response, user_api_key_dict):
try: try:
# since boto3 - sagemaker does not support async calls, we should use a sync data_generator # since boto3 - sagemaker does not support async calls, we should use a sync data_generator
if ( if hasattr(
hasattr(response, "custom_llm_provider") response, "custom_llm_provider"
and response.custom_llm_provider == "sagemaker" ) and response.custom_llm_provider in ["sagemaker", "together_ai"]:
):
return data_generator( return data_generator(
response=response, response=response,
) )
@@ -2243,13 +2242,14 @@ async def generate_key_fn(
if "max_budget" in data_json: if "max_budget" in data_json:
data_json["key_max_budget"] = data_json.pop("max_budget", None) data_json["key_max_budget"] = data_json.pop("max_budget", None)
if "budget_duration" in data_json: if "budget_duration" in data_json:
data_json["key_budget_duration"] = data_json.pop("budget_duration", None) data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
response = await generate_key_helper_fn(**data_json) response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse( return GenerateKeyResponse(
key=response["token"], expires=response["expires"], user_id=response["user_id"] key=response["token"],
expires=response["expires"],
user_id=response["user_id"],
) )
except Exception as e: except Exception as e:
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
@@ -2269,7 +2269,6 @@ async def generate_key_fn(
) )
@router.post( @router.post(
"/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)] "/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
) )