add support for custom hf prompt templates

Krrish Dholakia 2023-09-30 15:37:30 -07:00
parent 2af46e8be9
commit 16c755257b
6 changed files with 172 additions and 85 deletions

@@ -1,3 +1,7 @@
import requests, traceback
import json
from jinja2 import Template, exceptions, Environment, meta

def default_pt(messages):
    return " ".join(message["content"] for message in messages)
@@ -104,6 +108,76 @@ def phind_codellama_pt(messages):
            prompt += "### Assistant\n" + message["content"] + "\n\n"
    return prompt
def hf_chat_template(model: str, messages: list):
    ## get the tokenizer config from huggingface
    def _get_tokenizer_config(hf_model_name):
        url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
        # Make a GET request to fetch the JSON data
        response = requests.get(url)
        if response.status_code == 200:
            # Parse the JSON data
            tokenizer_config = json.loads(response.content)
            return {"status": "success", "tokenizer": tokenizer_config}
        else:
            return {"status": "failure"}

    tokenizer_config = _get_tokenizer_config(model)
    if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
        raise Exception("No chat template found")

    ## read the bos token, eos token and chat template from the json
    tokenizer_config = tokenizer_config["tokenizer"]
    bos_token = tokenizer_config["bos_token"]
    eos_token = tokenizer_config["eos_token"]
    chat_template = tokenizer_config["chat_template"]

    def raise_exception(message):
        raise Exception(f"Error message - {message}")

    # Create a template object from the template text
    env = Environment()
    env.globals["raise_exception"] = raise_exception
    template = env.from_string(chat_template)

    def _is_system_in_template():
        try:
            # Try rendering the template with a system message
            template.render(messages=[{"role": "system", "content": "test"}], eos_token="<eos>", bos_token="<bos>")
            return True
        except Exception:
            # Raised if the template rejects the system role (e.g. via raise_exception)
            return False

    try:
        # Render the template with the provided values
        if _is_system_in_template():
            rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=messages)
        else:
            # If the template has no system role, downgrade system messages to user messages
            try:
                reformatted_messages = []
                for message in messages:
                    if message["role"] == "system":
                        reformatted_messages.append({"role": "user", "content": message["content"]})
                    else:
                        reformatted_messages.append(message)
                rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=reformatted_messages)
            except Exception as e:
                if "Conversation roles must alternate user/assistant" in str(e):
                    # Ensure user/assistant messages alternate: insert a blank message of the
                    # opposite role between any two consecutive same-role messages
                    new_messages = []
                    for i in range(len(reformatted_messages) - 1):
                        new_messages.append(reformatted_messages[i])
                        if reformatted_messages[i]["role"] == reformatted_messages[i + 1]["role"]:
                            if reformatted_messages[i]["role"] == "user":
                                new_messages.append({"role": "assistant", "content": ""})
                            else:
                                new_messages.append({"role": "user", "content": ""})
                    new_messages.append(reformatted_messages[-1])
                    rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages)
                else:
                    raise
        return rendered_text
    except Exception:
        raise Exception("Error rendering template")
# Custom prompt template
def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str = "", final_prompt_value: str = ""):
    prompt = initial_prompt_value
@@ -117,27 +191,31 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
    return prompt
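
For illustration, a hedged sketch of calling custom_prompt. The function body is truncated in this hunk, so the assumed role_dict shape (per-role "pre_message"/"post_message" wrappers) is an inference from the signature, not confirmed by the diff:

# Assumed role_dict shape; treat this as a sketch, not the canonical format.
role_dict = {
    "system": {"pre_message": "<|system|>\n", "post_message": "\n"},
    "user": {"pre_message": "<|user|>\n", "post_message": "\n"},
    "assistant": {"pre_message": "<|assistant|>\n", "post_message": "\n"},
}
prompt = custom_prompt(
    role_dict=role_dict,
    messages=[{"role": "user", "content": "Hello world"}],
    initial_prompt_value="<s>",
    final_prompt_value="</s>",
)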
def prompt_factory(model: str, messages: list):
    original_model_name = model
    model = model.lower()
    try:
        if "meta-llama/llama-2" in model:
            if "chat" in model:
                return llama_2_chat_pt(messages=messages)
        elif "tiiuae/falcon" in model:  # Note: for the instruct models, it's best to use a User: .., Assistant: .. approach in your prompt template.
            if model == "tiiuae/falcon-180B-chat":
                return falcon_chat_pt(messages=messages)
            elif "instruct" in model:
                return falcon_instruct_pt(messages=messages)
        elif "mosaicml/mpt" in model:
            if "chat" in model:
                return mpt_chat_pt(messages=messages)
        elif "codellama/codellama" in model:
            if "instruct" in model:
                return llama_2_chat_pt(messages=messages)  # https://huggingface.co/blog/codellama#conversational-instructions
        elif "wizardlm/wizardcoder" in model:
            return wizardcoder_pt(messages=messages)
        elif "phind/phind-codellama" in model:
            return phind_codellama_pt(messages=messages)
        elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
            return llama_2_chat_pt(messages=messages)
        elif "mistralai/mistral" in model and "instruct" in model:
            return mistral_instruct_pt(messages=messages)
        else:
            return hf_chat_template(original_model_name, messages)
    except Exception:
        traceback.print_exc()
        return default_pt(messages=messages)  # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
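
A hedged sketch of the resulting routing (illustration only): known model families keep their hand-written templates, anything else is attempted via hf_chat_template, and any failure, such as a missing chat_template, falls back to default_pt:

# Hypothetical calls illustrating the routing above.
messages = [{"role": "user", "content": "Hey, how's it going?"}]
prompt_factory("meta-llama/Llama-2-7b-chat", messages)   # routed to llama_2_chat_pt
prompt_factory("some-org/some-chat-model", messages)     # hypothetical repo; tries hf_chat_template
prompt_factory("bigscience/bloom", messages)             # no chat template, so the except path returns default_pt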

@@ -1,17 +1,22 @@
import sys, os
import traceback
from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
from litellm.llms.prompt_templates.factory import prompt_factory

def test_prompt_formatting():
    try:
        prompt = prompt_factory(
            model="mistralai/Mistral-7B-Instruct-v0.1",
            messages=[
                {"role": "system", "content": "Be a good bot"},
                {"role": "user", "content": "Hello world"},
            ],
        )
        assert prompt == "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# def logger_fn(user_model_dict):
#     return
#     print(f"user_model_dict: {user_model_dict}")

@@ -314,58 +314,62 @@ def test_completion_cohere_stream_bad_key():
# test_completion_nlp_cloud_bad_key()

def test_completion_hf_stream():
    try:
        litellm.set_verbose = True
        # messages = [
        #     {
        #         "content": "Hello! How are you today?",
        #         "role": "user"
        #     },
        # ]
        # response = completion(
        #     model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
        # )
        # complete_response = ""
        # # Add any assertions here to check the response
        # for idx, chunk in enumerate(response):
        #     chunk, finished = streaming_format_tests(idx, chunk)
        #     if finished:
        #         break
        #     complete_response += chunk
        # if complete_response.strip() == "":
        #     raise Exception("Empty response received")
        # completion_response_1 = complete_response
        messages = [
            {
                "content": "Hello! How are you today?",
                "role": "user"
            },
            {
                "content": "I'm doing well, thank you for asking! I'm excited to be here and help you with any questions or concerns you may have. What can I assist you with today?",
                "role": "assistant"
            },
            {
                "content": "What is the price of crude oil?",
                "role": "user"
            },
        ]
        response = completion(
            model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
        )
        complete_response = ""
        # Add any assertions here to check the response
        for idx, chunk in enumerate(response):
            chunk, finished = streaming_format_tests(idx, chunk)
            if finished:
                break
            complete_response += chunk
        if complete_response.strip() == "":
            raise Exception("Empty response received")
        # print(f"completion_response_1: {completion_response_1}")
        print(f"completion_response: {complete_response}")
    except InvalidRequestError as e:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

test_completion_hf_stream()

# def test_completion_hf_stream_bad_key():
#     try:

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.803"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"