fix(sagemaker.py): enable passing hf model name for prompt template

Krrish Dholakia 2023-12-05 16:31:59 -08:00
parent a38504ff1b
commit 54d8a9df3f
3 changed files with 8 additions and 7 deletions

@@ -63,6 +63,7 @@ def completion(
     encoding,
     logging_obj,
     custom_prompt_dict={},
+    hf_model_name=None,
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -119,12 +120,7 @@ def completion(
             messages=messages
         )
     else:
-        hf_model_name = model
-        if "meta-textgeneration-llama-2" in model or "meta-textgenerationneuron-llama-2" in model: # llama2 model
-            if model.endswith("-f") or "-f-" in model or "chat" in model: # sagemaker default for a chat model
-                hf_model_name = "meta-llama/Llama-2-7b-chat" # apply the prompt template for a llama2 chat model
-            else:
-                hf_model_name = "meta-llama/Llama-2-7b" # apply the normal prompt template
+        hf_model_name = hf_model_name or model # pass in hf model name for pulling its prompt template (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
         prompt = prompt_factory(model=hf_model_name, messages=messages)
     data = json.dumps({
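
With this change, a caller can choose which Hugging Face prompt template is applied to a SageMaker deployment by passing hf_model_name alongside the SageMaker model. A minimal usage sketch, assuming hf_model_name is forwarded from the litellm.completion kwargs into this handler; the endpoint name below is a placeholder, not a real deployment:

import litellm

response = litellm.completion(
    model="sagemaker/your-llama2-endpoint",  # placeholder SageMaker endpoint name
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    hf_model_name="meta-llama/Llama-2-7b-chat-hf",  # HF model whose prompt template should be applied
)
print(response.choices[0].message.content)

Without hf_model_name, the handler now falls back to the SageMaker model string itself when resolving the prompt template, rather than hard-coding the Llama-2 special case it used before.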