adding first-party + custom prompt templates for huggingface

This commit is contained in:
Krrish Dholakia 2023-09-04 14:48:16 -07:00
parent a474b89779
commit 2384806cfd
10 changed files with 186 additions and 20 deletions

View file

@ -7,6 +7,7 @@ import time
from typing import Callable
from litellm.utils import ModelResponse
from typing import Optional
from .huggingface_model_prompt_templates.factory import prompt_factory, custom_prompt
class HuggingfaceError(Exception):
def __init__(self, status_code, message):
@ -33,6 +34,7 @@ def completion(
encoding,
api_key,
logging_obj,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None,
@ -47,21 +49,12 @@ def completion(
completion_url = os.getenv("HF_API_BASE", "")
else:
completion_url = f"https://api-inference.huggingface.co/models/{model}"
prompt = ""
if (
"meta-llama" in model and "chat" in model
): # use the required special tokens for meta-llama - https://huggingface.co/blog/llama2#how-to-prompt-llama-2
prompt = "<s>"
for message in messages:
if message["role"] == "system":
prompt += "[INST] <<SYS>>" + message["content"]
elif message["role"] == "assistant":
prompt += message["content"] + "</s><s>[INST]"
elif message["role"] == "user":
prompt += message["content"] + "[/INST]"
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(role_dict=model_prompt_details["roles"], pre_message_sep=model_prompt_details["pre_message_sep"], post_message_sep=model_prompt_details["post_message_sep"], messages=messages)
else:
for message in messages:
prompt += f"{message['content']}"
prompt = prompt_factory(model=model, messages=messages)
### MAP INPUT PARAMS
data = {
"inputs": prompt,