Fixed bugs in prompt factory for ibm-mistral and llama 3 models.

Simon Sanchez Viloria 2024-04-23 16:20:49 +02:00
parent 2ef4fb2efa
commit d72b725273

@@ -1362,10 +1362,11 @@ def prompt_factory(
     if "granite" in model and "chat" in model:
         # granite-13b-chat-v1 and granite-13b-chat-v2 use a specific prompt template
         return ibm_granite_pt(messages=messages)
-    elif "ibm-mistral" in model:
+    elif "ibm-mistral" in model and "instruct" in model:
+        # models like ibm-mistral/mixtral-8x7b-instruct-v01-q use the mistral instruct prompt template
         return mistral_instruct_pt(messages=messages)
     elif "meta-llama/llama-3" in model and "instruct" in model:
         # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
         return custom_prompt(
             role_dict={
                 "system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"},
@@ -1374,7 +1375,7 @@ def prompt_factory(
             },
             messages=messages,
             initial_prompt_value="<|begin_of_text|>",
-            # final_prompt_value="\n",
+            final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n",
         )
     try:
         if "meta-llama/llama-2" in model and "chat" in model: