mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

adding first-party + custom prompt templates for huggingface

This commit is contained in:
parent a474b89779
commit 2384806cfd

10 changed files with 186 additions and 20 deletions
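In practice, the change means a Hugging Face completion call now gets a model-appropriate prompt template applied automatically, and callers can register their own template per model. Below is a minimal usage sketch based on the test file added in this commit; the api_base value is a placeholder for a dedicated inference endpoint, so the network call is left commented out.

```python
import litellm
from litellm import completion

messages = [{"role": "user", "content": "Write me a function to print hello world"}]

# First-party template: a WizardCoder model is routed to the WizardCoder prompt format.
# Uncomment to run against a real Hugging Face inference endpoint (placeholder URL).
# response = completion(
#     model="huggingface/WizardLM/WizardCoder-Python-34B-V1.0",
#     messages=messages,
#     max_tokens=256,
#     api_base="https://<your-endpoint>.endpoints.huggingface.cloud",
# )

# Custom template: register once per model; the Hugging Face handler picks it up
# from litellm.custom_prompt_dict at request time.
litellm.register_prompt_template(
    model="togethercomputer/LLaMA-2-7B-32K",
    roles={"system": "", "assistant": "Assistant:", "user": "User:"},
    pre_message_sep="\n",
    post_message_sep="\n",
)
```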
litellm/__init__.py
@@ -47,7 +47,7 @@ def get_model_cost_map():
         print("Error occurred:", e)
         return None
 model_cost = get_model_cost_map()
+custom_prompt_dict = {}
 ####### THREAD-SPECIFIC DATA ###################
 class MyLocal(threading.local):
     def __init__(self):
@@ -224,7 +224,8 @@ from .utils import (
     acreate,
     get_model_list,
     completion_with_split_tests,
-    get_max_tokens
+    get_max_tokens,
+    register_prompt_template
 )
 from .main import *  # type: ignore
 from .integrations import *
Binary files (3) not shown.
litellm/llms/huggingface_model_prompt_templates/factory.py (new file, 109 lines)
@@ -0,0 +1,109 @@
def default_pt(messages):
    return " ".join(message["content"] for message in messages)

# Llama2 prompt template
def llama_2_chat_pt(messages):
    prompt = "<s>"
    for message in messages:
        if message["role"] == "system":
            prompt += "[INST] <<SYS>>" + message["content"]
        elif message["role"] == "assistant":
            prompt += message["content"] + "</s><s>[INST]"
        elif message["role"] == "user":
            prompt += message["content"] + "[/INST]"
    return prompt

def llama_2_pt(messages):
    return " ".join(message["content"] for message in messages)

# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def falcon_instruct_pt(messages):
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += message["content"]
        else:
            prompt += message['role'] + ":" + message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
            prompt += "\n\n"
    return prompt

# MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def mpt_chat_pt(messages):
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += "<|im_start|>system" + message["content"] + "<|im_end|>" + "\n"
        elif message["role"] == "assistant":
            prompt += "<|im_start|>assistant" + message["content"] + "<|im_end|>" + "\n"
        elif message["role"] == "user":
            prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n"
    return prompt

# WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format
def wizardcoder_pt(messages):
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += message["content"] + "\n\n"
        elif message["role"] == "user": # map to 'Instruction'
            prompt += "### Instruction:\n" + message["content"] + "\n\n"
        elif message["role"] == "assistant": # map to 'Response'
            prompt += "### Response:\n" + message["content"] + "\n\n"
    return prompt

# Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model
def phind_codellama_pt(messages):
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += "### System Prompt\n" + message["content"] + "\n\n"
        elif message["role"] == "user":
            prompt += "### User Message\n" + message["content"] + "\n\n"
        elif message["role"] == "assistant":
            prompt += "### Assistant\n" + message["content"] + "\n\n"
    return prompt

# Custom prompt template
def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str, messages: list):
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += f"{role_dict['system']}{pre_message_sep}" + message["content"] + post_message_sep
        elif message["role"] == "user":
            prompt += f"{role_dict['user']}{pre_message_sep}" + message["content"] + post_message_sep
        elif message["role"] == "assistant":
            prompt += f"{role_dict['assistant']}{pre_message_sep}" + message["content"] + post_message_sep
    return prompt

def prompt_factory(model: str, messages: list):
    model = model.lower()
    if "bloom" in model:
        return default_pt(messages=messages)
    elif "flan-t5" in model:
        return default_pt(messages=messages)
    elif "meta-llama" in model:
        if "chat" in model:
            return llama_2_chat_pt(messages=messages)
        else:
            return default_pt(messages=messages)
    elif "falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
        if "instruct" in model:
            return falcon_instruct_pt(messages=messages)
        else:
            return default_pt(messages=messages)
    elif "mpt" in model:
        if "chat" in model:
            return mpt_chat_pt(messages=messages)
        else:
            return default_pt(messages=messages)
    elif "codellama/codellama" in model:
        if "instruct" in model:
            return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
        else:
            return default_pt(messages=messages)
    elif "wizardcoder" in model:
        return wizardcoder_pt(messages=messages)
    elif "phind-codellama" in model:
        return phind_codellama_pt(messages=messages)
    else:
        return default_pt(messages=messages)
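For reference, a small sketch of what the factory above produces. The model names are illustrative (not taken from this commit's tests), and the expected-output comments follow directly from the template code; it assumes the module imports as laid out in this commit.

```python
from litellm.llms.huggingface_model_prompt_templates.factory import prompt_factory

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a haiku about the sea."},
]

# "meta-llama" + "chat" in the lowercased model name routes to llama_2_chat_pt:
# "<s>[INST] <<SYS>>You are a helpful assistant.Write a haiku about the sea.[/INST]"
print(prompt_factory(model="meta-llama/Llama-2-7b-chat-hf", messages=messages))

# Any unrecognized model falls back to default_pt, i.e. message contents joined by spaces:
# "You are a helpful assistant. Write a haiku about the sea."
print(prompt_factory(model="some-unlisted-model", messages=messages))
```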
litellm/llms/huggingface_restapi.py
@@ -7,6 +7,7 @@ import time
 from typing import Callable
 from litellm.utils import ModelResponse
 from typing import Optional
+from .huggingface_model_prompt_templates.factory import prompt_factory, custom_prompt
 
 class HuggingfaceError(Exception):
     def __init__(self, status_code, message):
@@ -33,6 +34,7 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -47,21 +49,12 @@ def completion(
         completion_url = os.getenv("HF_API_BASE", "")
     else:
         completion_url = f"https://api-inference.huggingface.co/models/{model}"
-    prompt = ""
-    if (
-        "meta-llama" in model and "chat" in model
-    ): # use the required special tokens for meta-llama - https://huggingface.co/blog/llama2#how-to-prompt-llama-2
-        prompt = "<s>"
-        for message in messages:
-            if message["role"] == "system":
-                prompt += "[INST] <<SYS>>" + message["content"]
-            elif message["role"] == "assistant":
-                prompt += message["content"] + "</s><s>[INST]"
-            elif message["role"] == "user":
-                prompt += message["content"] + "[/INST]"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(role_dict=model_prompt_details["roles"], pre_message_sep=model_prompt_details["pre_message_sep"], post_message_sep=model_prompt_details["post_message_sep"], messages=messages)
     else:
-        for message in messages:
-            prompt += f"{message['content']}"
+        prompt = prompt_factory(model=model, messages=messages)
     ### MAP INPUT PARAMS
     data = {
         "inputs": prompt,
litellm/main.py
@@ -563,8 +563,8 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 api_key=huggingface_key,
-                logging_obj=logging
+                logging_obj=logging,
+                custom_prompt_dict=litellm.custom_prompt_dict
             )
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
litellm/tests/test_hf_prompt_templates.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# import sys, os
# import traceback
# from dotenv import load_dotenv

# load_dotenv()
# import os

# sys.path.insert(
#     0, os.path.abspath("../..")
# )  # Adds the parent directory to the system path
# import pytest
# import litellm
# from litellm import embedding, completion, text_completion

# def logger_fn(user_model_dict):
#     return
#     print(f"user_model_dict: {user_model_dict}")

# messages=[{"role": "user", "content": "Write me a function to print hello world"}]

# # test if the first-party prompt templates work
# def test_huggingface_supported_models():
#     model = "huggingface/WizardLM/WizardCoder-Python-34B-V1.0"
#     response = completion(model=model, messages=messages, max_tokens=256, api_base="https://ji16r2iys9a8rjk2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
#     print(response['choices'][0]['message']['content'])
#     return response

# test_huggingface_supported_models()

# # test if a custom prompt template works
# litellm.register_prompt_template(
#     model="togethercomputer/LLaMA-2-7B-32K",
#     roles={"system":"", "assistant":"Assistant:", "user":"User:"},
#     pre_message_sep= "\n",
#     post_message_sep= "\n"
# )
# def test_huggingface_custom_model():
#     model = "huggingface/togethercomputer/LLaMA-2-7B-32K"
#     response = completion(model=model, messages=messages, api_base="https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
#     print(response['choices'][0]['message']['content'])
#     return response

# test_huggingface_custom_model()
litellm/utils.py
@@ -1326,6 +1326,27 @@ def modify_integration(integration_name, integration_params):
         Supabase.supabase_table_name = integration_params["table_name"]
 
 
+# custom prompt helper function
+def register_prompt_template(model: str, roles: dict, pre_message_sep: str, post_message_sep: str):
+    """
+    Example usage:
+    ```
+    import litellm
+    litellm.register_prompt_template(
+        model="bloomz",
+        roles={"system":"<|im_start|>system", "assistant":"<|im_start|>assistant", "user":"<|im_start|>user"},
+        pre_message_sep="\n",
+        post_message_sep="<|im_end|>\n"
+    )
+    ```
+    """
+    litellm.custom_prompt_dict[model] = {
+        "roles": roles,
+        "pre_message_sep": pre_message_sep,
+        "post_message_sep": post_message_sep
+    }
+    return litellm.custom_prompt_dict
+
 ####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
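To show how a registered entry is consumed, here is a small sketch tracing a template registered with the helper above through the `custom_prompt` function from the new factory module. The model name and roles mirror the custom-template test in this commit, and the sketch assumes the modules import as laid out here.

```python
import litellm
from litellm.llms.huggingface_model_prompt_templates.factory import custom_prompt

litellm.register_prompt_template(
    model="togethercomputer/LLaMA-2-7B-32K",
    roles={"system": "", "assistant": "Assistant:", "user": "User:"},
    pre_message_sep="\n",
    post_message_sep="\n",
)

# The Hugging Face handler looks the model up in litellm.custom_prompt_dict and
# hands the stored pieces to custom_prompt():
details = litellm.custom_prompt_dict["togethercomputer/LLaMA-2-7B-32K"]
prompt = custom_prompt(
    role_dict=details["roles"],
    pre_message_sep=details["pre_message_sep"],
    post_message_sep=details["post_message_sep"],
    messages=[{"role": "user", "content": "Hello"}],
)
print(repr(prompt))  # 'User:\nHello\n'
```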
@@ -1415,7 +1436,6 @@ def get_model_list():
             f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
         )
 
-
 ####### EXCEPTION MAPPING ################
 def exception_type(model, original_exception, custom_llm_provider):
     global user_logger_fn, liteDebuggerClient
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.531"
+version = "0.1.532"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"