(fix) sagemaker Llama-2 70b

ishaan-jaff 2023-12-05 15:32:15 -08:00
parent 68ca2a28d4
commit c4bda13820


@@ -2206,7 +2206,9 @@ def get_optional_params( # use the openai defaults
     if max_tokens is not None:
         optional_params["max_output_tokens"] = max_tokens
 elif custom_llm_provider == "sagemaker":
-    if "llama-2" in model:
+    if "llama-2" in model.lower() or (
+        "llama" in model.lower() and "2" in model.lower()  # some combination of "llama" and "2" should exist
+    ):  # jumpstart can also send "Llama-2-70b-chat-hf-48xlarge"
         # llama-2 models on sagemaker support the following args
         """
         max_new_tokens: Model generates text until the output length (excluding the input context length) reaches max_new_tokens. If specified, it must be a positive integer.
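For context, a minimal standalone sketch of the relaxed match this commit introduces. Only the boolean expression comes from the diff; the helper name and the model ids in the asserts are hypothetical illustrations:

def _is_sagemaker_llama2(model: str) -> bool:
    # True for plain "llama-2" ids and for JumpStart-style ids such as
    # "Llama-2-70b-chat-hf-48xlarge", where casing varies and extra
    # prefixes/suffixes surround the "llama" and "2" tokens.
    m = model.lower()
    return "llama-2" in m or ("llama" in m and "2" in m)

# Hypothetical model ids (the second is the example cited in the diff):
assert _is_sagemaker_llama2("meta-textgeneration-llama-2-7b")
assert _is_sagemaker_llama2("Llama-2-70b-chat-hf-48xlarge")
assert not _is_sagemaker_llama2("falcon-40b-instruct")

Note the design trade-off: the second clause is deliberately loose (any name containing both "llama" and "2" matches), accepting possible false positives in exchange for catching JumpStart naming variants that drop the exact "llama-2" substring.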