Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
add support for custom hf prompt templates
commit 16c755257b (parent 2af46e8be9)
6 changed files with 172 additions and 85 deletions
2 binary files changed (contents not shown)
litellm/llms/prompt_templates/factory.py
@@ -1,3 +1,7 @@
+import requests, traceback
+import json
+from jinja2 import Template, exceptions, Environment, meta
+
 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
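
For context (not part of the diff), default_pt is the most basic template: it ignores roles and simply joins message contents with spaces:

    example_messages = [
        {"role": "system", "content": "Be a good bot"},
        {"role": "user", "content": "Hello world"},
    ]
    print(default_pt(example_messages))  # -> "Be a good bot Hello world"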
@@ -104,6 +108,76 @@ def phind_codellama_pt(messages):
         prompt += "### Assistant\n" + message["content"] + "\n\n"
     return prompt
 
+def hf_chat_template(model: str, messages: list):
+    ## get the tokenizer config from huggingface
+    def _get_tokenizer_config(hf_model_name):
+        url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
+        # Make a GET request to fetch the JSON data
+        response = requests.get(url)
+        if response.status_code == 200:
+            # Parse the JSON data
+            tokenizer_config = json.loads(response.content)
+            return {"status": "success", "tokenizer": tokenizer_config}
+        else:
+            return {"status": "failure"}
+    tokenizer_config = _get_tokenizer_config(model)
+    if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
+        raise Exception("No chat template found")
+    ## read the bos token, eos token and chat template from the json
+    tokenizer_config = tokenizer_config["tokenizer"]
+    bos_token = tokenizer_config["bos_token"]
+    eos_token = tokenizer_config["eos_token"]
+    chat_template = tokenizer_config["chat_template"]
+
+    def raise_exception(message):
+        raise Exception(f"Error message - {message}")
+
+    # Create a template object from the template text
+    env = Environment()
+    env.globals['raise_exception'] = raise_exception
+    template = env.from_string(chat_template)
+
+    def _is_system_in_template():
+        try:
+            # Try rendering the template with a system message
+            response = template.render(messages=[{"role": "system", "content": "test"}], eos_token= "<eos>", bos_token= "<bos>")
+            return True
+
+        # This will be raised if Jinja attempts to render the system message and it can't
+        except:
+            return False
+
+    try:
+        # Render the template with the provided values
+        if _is_system_in_template():
+            rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=messages)
+        else:
+            # treat a system message as a user message, if system not in template
+            try:
+                reformatted_messages = []
+                for message in messages:
+                    if message["role"] == "system":
+                        reformatted_messages.append({"role": "user", "content": message["content"]})
+                    else:
+                        reformatted_messages.append(message)
+                rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=reformatted_messages)
+            except Exception as e:
+                if "Conversation roles must alternate user/assistant" in str(e):
+                    # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
+                    new_messages = []
+                    for i in range(len(reformatted_messages)-1):
+                        new_messages.append(reformatted_messages[i])
+                        if reformatted_messages[i]["role"] == reformatted_messages[i+1]["role"]:
+                            if reformatted_messages[i]["role"] == "user":
+                                new_messages.append({"role": "assistant", "content": ""})
+                            else:
+                                new_messages.append({"role": "user", "content": ""})
+                    new_messages.append(reformatted_messages[-1])
+                    rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages)
+        return rendered_text
+    except:
+        raise Exception("Error rendering template")
+
 # Custom prompt template
 def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str=""):
     prompt = initial_prompt_value
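
To make the new helper concrete, here is a small usage sketch (illustrative, not part of the diff). It assumes network access to huggingface.co and that the referenced repo publishes a chat_template in its tokenizer_config.json; mistralai/Mistral-7B-Instruct-v0.1 is the model the new test below exercises:

    example_messages = [
        {"role": "system", "content": "Be a good bot"},
        {"role": "user", "content": "Hello world"},
    ]
    # Fetches tokenizer_config.json, builds a jinja2 template from its
    # chat_template, and renders the messages with the model's bos/eos tokens.
    prompt = hf_chat_template("mistralai/Mistral-7B-Instruct-v0.1", example_messages)
    print(prompt)
    # The new test below expects:
    # "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"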
@@ -117,27 +191,31 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str=""):
     return prompt
 
 def prompt_factory(model: str, messages: list):
+    original_model_name = model
     model = model.lower()
-    if "meta-llama/llama-2" in model:
-        if "chat" in model:
-            return llama_2_chat_pt(messages=messages)
-    elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
-        if model == "tiiuae/falcon-180B-chat":
-            return falcon_chat_pt(messages=messages)
-        elif "instruct" in model:
-            return falcon_instruct_pt(messages=messages)
-    elif "mosaicml/mpt" in model:
-        if "chat" in model:
-            return mpt_chat_pt(messages=messages)
-    elif "codellama/codellama" in model:
-        if "instruct" in model:
-            return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
-    elif "wizardlm/wizardcoder" in model:
-        return wizardcoder_pt(messages=messages)
-    elif "phind/phind-codellama" in model:
-        return phind_codellama_pt(messages=messages)
-    elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
-        return llama_2_chat_pt(messages=messages)
-    elif "mistralai/mistral" in model and "instruct" in model:
-        return mistral_instruct_pt(messages=messages)
-    return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
+    try:
+        if "meta-llama/llama-2" in model:
+            if "chat" in model:
+                return llama_2_chat_pt(messages=messages)
+        elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
+            if model == "tiiuae/falcon-180B-chat":
+                return falcon_chat_pt(messages=messages)
+            elif "instruct" in model:
+                return falcon_instruct_pt(messages=messages)
+        elif "mosaicml/mpt" in model:
+            if "chat" in model:
+                return mpt_chat_pt(messages=messages)
+        elif "codellama/codellama" in model:
+            if "instruct" in model:
+                return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
+        elif "wizardlm/wizardcoder" in model:
+            return wizardcoder_pt(messages=messages)
+        elif "phind/phind-codellama" in model:
+            return phind_codellama_pt(messages=messages)
+        elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
+            return llama_2_chat_pt(messages=messages)
+        else:
+            return hf_chat_template(original_model_name, messages)
+    except:
+        traceback.print_exc()
+        return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
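
A sketch of the routing after this change (model names are only examples; the else branch needs network access to fetch the template, and any failure falls back to default_pt):

    # A known chat model keeps its hand-written template
    llama_prompt = prompt_factory(
        model="meta-llama/Llama-2-7b-chat-hf",
        messages=[{"role": "user", "content": "Hi"}],
    )
    # Anything unrecognised falls through to hf_chat_template, which tries the
    # model's own chat_template from tokenizer_config.json on the Hub
    hf_prompt = prompt_factory(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        messages=[{"role": "user", "content": "Hi"}],
    )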
@@ -1,17 +1,22 @@
-# import sys, os
-# import traceback
-# from dotenv import load_dotenv
+import sys, os
+import traceback
+from dotenv import load_dotenv
 
-# load_dotenv()
-# import os
+load_dotenv()
+import os
 
-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# ) # Adds the parent directory to the system path
-# import pytest
-# import litellm
-# from litellm import embedding, completion, text_completion
+sys.path.insert(
+    0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+import pytest
+from litellm.llms.prompt_templates.factory import prompt_factory
+
+def test_prompt_formatting():
+    try:
+        prompt = prompt_factory(model="mistralai/Mistral-7B-Instruct-v0.1", messages=[{"role": "system", "content": "Be a good bot"}, {"role": "user", "content": "Hello world"}])
+        assert prompt == "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
 # def logger_fn(user_model_dict):
 #     return
 #     print(f"user_model_dict: {user_model_dict}")
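
A brief note on why that exact string is expected (illustrative, assuming the upstream Mistral chat template rejects a system role, which is what the fallback in hf_chat_template handles): the system message is re-labelled as a user message, the two consecutive user turns then trip the template's role-alternation check, a blank assistant turn is inserted, and the rendered pieces concatenate like this:

    # "<s>"                           bos_token
    # "[INST] Be a good bot [/INST]"  system message re-labelled as user
    # "</s> "                         blank assistant turn (eos_token plus a space)
    # "[INST] Hello world [/INST]"    user message
    expected = "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"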
@@ -314,58 +314,62 @@ def test_completion_cohere_stream_bad_key():
 
 # test_completion_nlp_cloud_bad_key()
 
-# def test_completion_hf_stream():
-#     try:
-#         litellm.set_verbose = True
-#         # messages = [
-#         #     {
-#         #         "content": "Hello! How are you today?",
-#         #         "role": "user"
-#         #     },
-#         # ]
-#         # response = completion(
-#         #     model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
-#         # )
-#         # complete_response = ""
-#         # # Add any assertions here to check the response
-#         # for idx, chunk in enumerate(response):
-#         #     chunk, finished = streaming_format_tests(idx, chunk)
-#         #     if finished:
-#         #         break
-#         #     complete_response += chunk
-#         # if complete_response.strip() == "":
-#         #     raise Exception("Empty response received")
-#         # completion_response_1 = complete_response
-#         messages = [
-#             {
-#                 "content": "Hello! How are you today?",
-#                 "role": "user"
-#             },
-#             {
-#                 "content": "I'm doing well, thank you for asking! I'm excited to be here and help you with any questions or concerns you may have. What can I assist you with today?</s>",
-#                 "role": "assistant"
-#             },
-#         ]
-#         response = completion(
-#             model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
-#         )
-#         complete_response = ""
-#         # Add any assertions here to check the response
-#         for idx, chunk in enumerate(response):
-#             chunk, finished = streaming_format_tests(idx, chunk)
-#             if finished:
-#                 break
-#             complete_response += chunk
-#         if complete_response.strip() == "":
-#             raise Exception("Empty response received")
-#         # print(f"completion_response_1: {completion_response_1}")
-#         print(f"completion_response: {complete_response}")
-#     except InvalidRequestError as e:
-#         pass
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def test_completion_hf_stream():
+    try:
+        litellm.set_verbose = True
+        # messages = [
+        #     {
+        #         "content": "Hello! How are you today?",
+        #         "role": "user"
+        #     },
+        # ]
+        # response = completion(
+        #     model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+        # )
+        # complete_response = ""
+        # # Add any assertions here to check the response
+        # for idx, chunk in enumerate(response):
+        #     chunk, finished = streaming_format_tests(idx, chunk)
+        #     if finished:
+        #         break
+        #     complete_response += chunk
+        # if complete_response.strip() == "":
+        #     raise Exception("Empty response received")
+        # completion_response_1 = complete_response
+        messages = [
+            {
+                "content": "Hello! How are you today?",
+                "role": "user"
+            },
+            {
+                "content": "I'm doing well, thank you for asking! I'm excited to be here and help you with any questions or concerns you may have. What can I assist you with today?",
+                "role": "assistant"
+            },
+            {
+                "content": "What is the price of crude oil?",
+                "role": "user"
+            },
+        ]
+        response = completion(
+            model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+        )
+        complete_response = ""
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            if finished:
+                break
+            complete_response += chunk
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        # print(f"completion_response_1: {completion_response_1}")
+        print(f"completion_response: {complete_response}")
+    except InvalidRequestError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 
-# test_completion_hf_stream()
+test_completion_hf_stream()
 
 # def test_completion_hf_stream_bad_key():
 #     try:
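
As an aside, outside the test harness the same streaming call is consumed chunk by chunk. This is only a sketch: the endpoint URL is a placeholder, and it assumes the OpenAI-style delta chunks that litellm normalizes streaming responses into:

    import litellm

    response = litellm.completion(
        model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",
        messages=[{"role": "user", "content": "Hello! How are you today?"}],
        api_base="https://<your-endpoint>.endpoints.huggingface.cloud",  # placeholder
        stream=True,
        max_tokens=100,
    )
    streamed_text = ""
    for chunk in response:
        # each chunk mirrors OpenAI's streaming shape: choices[0].delta.content
        delta = chunk["choices"][0]["delta"]
        streamed_text += delta.get("content") or ""
    print(streamed_text)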
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.802"
+version = "0.1.803"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"