adding first-party + custom prompt templates for huggingface

Krrish Dholakia 2023-09-04 14:48:16 -07:00
parent a474b89779
commit 2384806cfd
10 changed files with 186 additions and 20 deletions

View file

@@ -47,7 +47,7 @@ def get_model_cost_map():
        print("Error occurred:", e)
        return None
model_cost = get_model_cost_map()
+custom_prompt_dict = {}
####### THREAD-SPECIFIC DATA ###################
class MyLocal(threading.local):
    def __init__(self):
@@ -224,7 +224,8 @@ from .utils import (
    acreate,
    get_model_list,
    completion_with_split_tests,
-    get_max_tokens
+    get_max_tokens,
+    register_prompt_template
)
from .main import * # type: ignore
from .integrations import *

View file

@@ -0,0 +1,109 @@
def default_pt(messages):
    return " ".join(message["content"] for message in messages)

# Llama2 prompt template
def llama_2_chat_pt(messages):
    prompt = "<s>"
    for message in messages:
        if message["role"] == "system":
            prompt += "[INST] <<SYS>>" + message["content"]
        elif message["role"] == "assistant":
            prompt += message["content"] + "</s><s>[INST]"
        elif message["role"] == "user":
            prompt += message["content"] + "[/INST]"
    return prompt

def llama_2_pt(messages):
    return " ".join(message["content"] for message in messages)

# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def falcon_instruct_pt(messages):
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += message["content"]
        else:
            prompt += message["role"] + ":" + message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
            prompt += "\n\n"
    return prompt

# MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def mpt_chat_pt(messages):
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += "<|im_start|>system" + message["content"] + "<|im_end|>" + "\n"
        elif message["role"] == "assistant":
            prompt += "<|im_start|>assistant" + message["content"] + "<|im_end|>" + "\n"
        elif message["role"] == "user":
            prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n"
    return prompt

# WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format
def wizardcoder_pt(messages):
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += message["content"] + "\n\n"
        elif message["role"] == "user":  # map to 'Instruction'
            prompt += "### Instruction:\n" + message["content"] + "\n\n"
        elif message["role"] == "assistant":  # map to 'Response'
            prompt += "### Response:\n" + message["content"] + "\n\n"
    return prompt

# Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model
def phind_codellama_pt(messages):
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += "### System Prompt\n" + message["content"] + "\n\n"
        elif message["role"] == "user":
            prompt += "### User Message\n" + message["content"] + "\n\n"
        elif message["role"] == "assistant":
            prompt += "### Assistant\n" + message["content"] + "\n\n"
    return prompt

# Custom prompt template
def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str, messages: list):
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += f"{role_dict['system']}{pre_message_sep}" + message["content"] + post_message_sep
        elif message["role"] == "user":
            prompt += f"{role_dict['user']}{pre_message_sep}" + message["content"] + post_message_sep
        elif message["role"] == "assistant":
            prompt += f"{role_dict['assistant']}{pre_message_sep}" + message["content"] + post_message_sep
    return prompt

def prompt_factory(model: str, messages: list):
    model = model.lower()
    if "bloom" in model:
        return default_pt(messages=messages)
    elif "flan-t5" in model:
        return default_pt(messages=messages)
    elif "meta-llama" in model:
        if "chat" in model:
            return llama_2_chat_pt(messages=messages)
        else:
            return default_pt(messages=messages)
    elif "falcon" in model:  # Note: for the instruct models, it's best to use a User: .., Assistant: .. approach in your prompt template.
        if "instruct" in model:
            return falcon_instruct_pt(messages=messages)
        else:
            return default_pt(messages=messages)
    elif "mpt" in model:
        if "chat" in model:
            return mpt_chat_pt(messages=messages)
        else:
            return default_pt(messages=messages)
    elif "codellama/codellama" in model:
        if "instruct" in model:
            return llama_2_chat_pt(messages=messages)  # https://huggingface.co/blog/codellama#conversational-instructions
        else:
            return default_pt(messages=messages)
    elif "wizardcoder" in model:
        return wizardcoder_pt(messages=messages)
    elif "phind-codellama" in model:
        return phind_codellama_pt(messages=messages)
    else:
        return default_pt(messages=messages)
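
Not part of the commit: a minimal sanity-check sketch of how `prompt_factory` routes a model name to one of the templates above. The absolute import path is an assumption, inferred from the relative import used in the HuggingFace handler further down.

```python
# Hypothetical usage sketch (not in this commit).
# Assumed import path, based on the handler's relative import.
from litellm.llms.huggingface_model_prompt_templates.factory import prompt_factory, wizardcoder_pt

messages = [
    {"role": "system", "content": "You are a helpful coding assistant."},
    {"role": "user", "content": "Write me a function to print hello world"},
]

# "wizardcoder" appears in the lowercased model name, so wizardcoder_pt is chosen.
prompt = prompt_factory(model="WizardLM/WizardCoder-Python-34B-V1.0", messages=messages)
assert prompt == wizardcoder_pt(messages)
print(prompt)
# You are a helpful coding assistant.
#
# ### Instruction:
# Write me a function to print hello world
#
```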

View file

@@ -7,6 +7,7 @@ import time
from typing import Callable
from litellm.utils import ModelResponse
from typing import Optional
+from .huggingface_model_prompt_templates.factory import prompt_factory, custom_prompt

class HuggingfaceError(Exception):
    def __init__(self, status_code, message):
@@ -33,6 +34,7 @@ def completion(
    encoding,
    api_key,
    logging_obj,
+    custom_prompt_dict={},
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
@@ -47,21 +49,12 @@ def completion(
        completion_url = os.getenv("HF_API_BASE", "")
    else:
        completion_url = f"https://api-inference.huggingface.co/models/{model}"
-    prompt = ""
-    if (
-        "meta-llama" in model and "chat" in model
-    ):  # use the required special tokens for meta-llama - https://huggingface.co/blog/llama2#how-to-prompt-llama-2
-        prompt = "<s>"
-        for message in messages:
-            if message["role"] == "system":
-                prompt += "[INST] <<SYS>>" + message["content"]
-            elif message["role"] == "assistant":
-                prompt += message["content"] + "</s><s>[INST]"
-            elif message["role"] == "user":
-                prompt += message["content"] + "[/INST]"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(role_dict=model_prompt_details["roles"], pre_message_sep=model_prompt_details["pre_message_sep"], post_message_sep=model_prompt_details["post_message_sep"], messages=messages)
    else:
-        for message in messages:
-            prompt += f"{message['content']}"
+        prompt = prompt_factory(model=model, messages=messages)
    ### MAP INPUT PARAMS
    data = {
        "inputs": prompt,

View file

@@ -563,8 +563,8 @@ def completion(
            logger_fn=logger_fn,
            encoding=encoding,
            api_key=huggingface_key,
-            logging_obj=logging
+            logging_obj=logging,
+            custom_prompt_dict=litellm.custom_prompt_dict
        )
        if "stream" in optional_params and optional_params["stream"] == True:
            # don't try to access stream object,

View file

@@ -0,0 +1,43 @@
# import sys, os
# import traceback
# from dotenv import load_dotenv

# load_dotenv()
# import os

# sys.path.insert(
#     0, os.path.abspath("../..")
# )  # Adds the parent directory to the system path
# import pytest
# import litellm
# from litellm import embedding, completion, text_completion

# def logger_fn(user_model_dict):
#     return
#     print(f"user_model_dict: {user_model_dict}")

# messages=[{"role": "user", "content": "Write me a function to print hello world"}]

# # test if the first-party prompt templates work
# def test_huggingface_supported_models():
#     model = "huggingface/WizardLM/WizardCoder-Python-34B-V1.0"
#     response = completion(model=model, messages=messages, max_tokens=256, api_base="https://ji16r2iys9a8rjk2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
#     print(response['choices'][0]['message']['content'])
#     return response

# test_huggingface_supported_models()

# # test if a custom prompt template works
# litellm.register_prompt_template(
#     model="togethercomputer/LLaMA-2-7B-32K",
#     roles={"system":"", "assistant":"Assistant:", "user":"User:"},
#     pre_message_sep="\n",
#     post_message_sep="\n"
# )
# def test_huggingface_custom_model():
#     model = "huggingface/togethercomputer/LLaMA-2-7B-32K"
#     response = completion(model=model, messages=messages, api_base="https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
#     print(response['choices'][0]['message']['content'])
#     return response

# test_huggingface_custom_model()

View file

@@ -1326,6 +1326,27 @@ def modify_integration(integration_name, integration_params):
        Supabase.supabase_table_name = integration_params["table_name"]

+# custom prompt helper function
+def register_prompt_template(model: str, roles: dict, pre_message_sep: str, post_message_sep: str):
+    """
+    Example usage:
+    ```
+    import litellm
+    litellm.register_prompt_template(
+        model="bloomz",
+        roles={"system": "<|im_start|>system", "assistant": "<|im_start|>assistant", "user": "<|im_start|>user"},
+        pre_message_sep="\n",
+        post_message_sep="<|im_end|>\n"
+    )
+    ```
+    """
+    litellm.custom_prompt_dict[model] = {
+        "roles": roles,
+        "pre_message_sep": pre_message_sep,
+        "post_message_sep": post_message_sep
+    }
+    return litellm.custom_prompt_dict
+
####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
@@ -1415,7 +1436,6 @@ def get_model_list():
            f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
        )

####### EXCEPTION MAPPING ################
def exception_type(model, original_exception, custom_llm_provider):
    global user_logger_fn, liteDebuggerClient
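
For reference, a hedged sketch (not part of the commit) of what `register_prompt_template` stores and how the HuggingFace path then renders the prompt via `custom_prompt`. The factory import path is an assumption; the roles and separators mirror the commented-out test above.

```python
# Hypothetical end-to-end sketch: register a template, then render a prompt
# the same way the HuggingFace handler does. Assumed import path.
import litellm
from litellm.llms.huggingface_model_prompt_templates.factory import custom_prompt

litellm.register_prompt_template(
    model="togethercomputer/LLaMA-2-7B-32K",
    roles={"system": "", "assistant": "Assistant:", "user": "User:"},
    pre_message_sep="\n",
    post_message_sep="\n",
)

# The registration lands in litellm.custom_prompt_dict, keyed by model name.
details = litellm.custom_prompt_dict["togethercomputer/LLaMA-2-7B-32K"]
prompt = custom_prompt(
    role_dict=details["roles"],
    pre_message_sep=details["pre_message_sep"],
    post_message_sep=details["post_message_sep"],
    messages=[{"role": "user", "content": "Write me a function to print hello world"}],
)
print(repr(prompt))  # 'User:\nWrite me a function to print hello world\n'
```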

View file

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "0.1.531"
+version = "0.1.532"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"