From 2d26875eb0b9bd347e9b9d8c7d6fced739d9d5be Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 23 Jan 2024 20:07:26 -0800 Subject: [PATCH] (fix) together_ai use sync generator --- litellm/proxy/proxy_config.yaml | 3 +++ litellm/proxy/proxy_server.py | 15 +++++++-------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 36f0aeb10..97168b19f 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -25,6 +25,9 @@ model_list: - model_name: BEDROCK_GROUP litellm_params: model: bedrock/cohere.command-text-v14 + - model_name: tg-ai + litellm_params: + model: together_ai/mistralai/Mistral-7B-Instruct-v0.1 - model_name: sagemaker litellm_params: model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4 diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index d2f321fd9..cc4222f8e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1450,10 +1450,9 @@ async def async_data_generator(response, user_api_key_dict): def select_data_generator(response, user_api_key_dict): try: # since boto3 - sagemaker does not support async calls, we should use a sync data_generator - if ( - hasattr(response, "custom_llm_provider") - and response.custom_llm_provider == "sagemaker" - ): + if hasattr( + response, "custom_llm_provider" + ) and response.custom_llm_provider in ["sagemaker", "together_ai"]: return data_generator( response=response, ) @@ -2243,13 +2242,14 @@ async def generate_key_fn( if "max_budget" in data_json: data_json["key_max_budget"] = data_json.pop("max_budget", None) - if "budget_duration" in data_json: - data_json["key_budget_duration"] = data_json.pop("budget_duration", None) + data_json["key_budget_duration"] = data_json.pop("budget_duration", None) response = await generate_key_helper_fn(**data_json) return GenerateKeyResponse( - key=response["token"], expires=response["expires"], user_id=response["user_id"] + key=response["token"], + expires=response["expires"], + user_id=response["user_id"], ) except Exception as e: if isinstance(e, HTTPException): @@ -2268,7 +2268,6 @@ async def generate_key_fn( code=status.HTTP_400_BAD_REQUEST, ) - @router.post( "/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)]