diff --git a/litellm/__init__.py b/litellm/__init__.py
index a4ed950a8..f3b51947c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -47,7 +47,7 @@ def get_model_cost_map():
print("Error occurred:", e)
return None
model_cost = get_model_cost_map()
-
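+# custom prompt templates registered at runtime via litellm.register_prompt_template(), keyed by model name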
+custom_prompt_dict = {}
####### THREAD-SPECIFIC DATA ###################
class MyLocal(threading.local):
def __init__(self):
@@ -224,7 +224,8 @@ from .utils import (
acreate,
get_model_list,
completion_with_split_tests,
- get_max_tokens
+ get_max_tokens,
+ register_prompt_template
)
from .main import * # type: ignore
from .integrations import *
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index e68c4da03..684c6587e 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 4e3285a90..809ab0689 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 309ff5b2e..7a8c10db6 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/huggingface_model_prompt_templates/factory.py b/litellm/llms/huggingface_model_prompt_templates/factory.py
new file mode 100644
index 000000000..b2c5c1425
--- /dev/null
+++ b/litellm/llms/huggingface_model_prompt_templates/factory.py
@@ -0,0 +1,109 @@
+def default_pt(messages):
+ return " ".join(message["content"] for message in messages)
+
+# Llama2 prompt template
+def llama_2_chat_pt(messages):
+ prompt = ""
+ for message in messages:
+ if message["role"] == "system":
+            prompt += "[INST] <<SYS>>" + message["content"]
+ elif message["role"] == "assistant":
+ prompt += message["content"] + "[INST]"
+ elif message["role"] == "user":
+ prompt += message["content"] + "[/INST]"
+ return prompt
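+# rough illustration: [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]
+# produces "[INST] <<SYS>>Be brief.Hi[/INST]" (full format: https://huggingface.co/blog/llama2#how-to-prompt-llama-2)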
+
+def llama_2_pt(messages):
+ return " ".join(message["content"] for message in messages)
+
+# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
+def falcon_instruct_pt(messages):
+    prompt = ""
+    for message in messages:
+        if message["role"] == "system":
+            prompt += message["content"]
+        else:
+            prompt += message["role"] + ": " + message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
+            prompt += "\n\n"
+    return prompt
+
+# MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
+def mpt_chat_pt(messages):
+ prompt = ""
+ for message in messages:
+ if message["role"] == "system":
+ prompt += "<|im_start|>system" + message["content"] + "<|im_end|>" + "\n"
+ elif message["role"] == "assistant":
+ prompt += "<|im_start|>assistant" + message["content"] + "<|im_end|>" + "\n"
+ elif message["role"] == "user":
+ prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n"
+ return prompt
+
+# WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format
+def wizardcoder_pt(messages):
+ prompt = ""
+ for message in messages:
+ if message["role"] == "system":
+ prompt += message["content"] + "\n\n"
+ elif message["role"] == "user": # map to 'Instruction'
+ prompt += "### Instruction:\n" + message["content"] + "\n\n"
+ elif message["role"] == "assistant": # map to 'Response'
+ prompt += "### Response:\n" + message["content"] + "\n\n"
+ return prompt
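+# rough illustration: a single user message {"role": "user", "content": "Write a bubble sort"}
+# produces "### Instruction:\nWrite a bubble sort\n\n" (see the model card linked above for the full format)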
+
+# Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model
+def phind_codellama_pt(messages):
+ prompt = ""
+ for message in messages:
+ if message["role"] == "system":
+ prompt += "### System Prompt\n" + message["content"] + "\n\n"
+ elif message["role"] == "user":
+ prompt += "### User Message\n" + message["content"] + "\n\n"
+ elif message["role"] == "assistant":
+ prompt += "### Assistant\n" + message["content"] + "\n\n"
+ return prompt
+
+# Custom prompt template
+def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str, messages: list):
+ prompt = ""
+ for message in messages:
+ if message["role"] == "system":
+ prompt += f"{role_dict['system']}{pre_message_sep}" + message["content"] + post_message_sep
+ elif message["role"] == "user":
+ prompt += f"{role_dict['user']}{pre_message_sep}" + message["content"] + post_message_sep
+ elif message["role"] == "assistant":
+ prompt += f"{role_dict['assistant']}{pre_message_sep}" + message["content"] + post_message_sep
+ return prompt
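+# rough illustration, using a hypothetical ChatML-style registration:
+#   custom_prompt(role_dict={"user": "<|im_start|>user"}, pre_message_sep="\n", post_message_sep="<|im_end|>\n",
+#                 messages=[{"role": "user", "content": "Hi"}])
+#   returns "<|im_start|>user\nHi<|im_end|>\n"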
+
+def prompt_factory(model: str, messages: list):
+ model = model.lower()
+ if "bloom" in model:
+ return default_pt(messages=messages)
+ elif "flan-t5" in model:
+ return default_pt(messages=messages)
+ elif "meta-llama" in model:
+ if "chat" in model:
+ return llama_2_chat_pt(messages=messages)
+ else:
+ return default_pt(messages=messages)
+ elif "falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
+ if "instruct" in model:
+ return falcon_instruct_pt(messages=messages)
+ else:
+ return default_pt(messages=messages)
+ elif "mpt" in model:
+ if "chat" in model:
+ return mpt_chat_pt(messages=messages)
+ else:
+ return default_pt(messages=messages)
+ elif "codellama/codellama" in model:
+ if "instruct" in model:
+ return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
+ else:
+ return default_pt(messages=messages)
+ elif "wizardcoder" in model:
+ return wizardcoder_pt(messages=messages)
+ elif "phind-codellama" in model:
+ return phind_codellama_pt(messages=messages)
+ else:
+ return default_pt(messages=messages)
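+
+# rough illustration (model names are matched as lowercase substrings):
+#   prompt_factory("WizardLM/WizardCoder-Python-34B-V1.0", [{"role": "user", "content": "hi"}])
+#   routes to wizardcoder_pt; unrecognized models fall back to default_pt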
\ No newline at end of file
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index dcd7c3efd..51b61b0d8 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -7,6 +7,7 @@ import time
from typing import Callable
from litellm.utils import ModelResponse
from typing import Optional
+from .huggingface_model_prompt_templates.factory import prompt_factory, custom_prompt
class HuggingfaceError(Exception):
def __init__(self, status_code, message):
@@ -33,6 +34,7 @@ def completion(
encoding,
api_key,
logging_obj,
+ custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None,
@@ -47,21 +49,12 @@ def completion(
completion_url = os.getenv("HF_API_BASE", "")
else:
completion_url = f"https://api-inference.huggingface.co/models/{model}"
- prompt = ""
- if (
- "meta-llama" in model and "chat" in model
- ): # use the required special tokens for meta-llama - https://huggingface.co/blog/llama2#how-to-prompt-llama-2
- prompt = ""
- for message in messages:
- if message["role"] == "system":
-                prompt += "[INST] <<SYS>>" + message["content"]
- elif message["role"] == "assistant":
- prompt += message["content"] + "[INST]"
- elif message["role"] == "user":
- prompt += message["content"] + "[/INST]"
+ if model in custom_prompt_dict:
+ # check if the model has a registered custom prompt
+ model_prompt_details = custom_prompt_dict[model]
+ prompt = custom_prompt(role_dict=model_prompt_details["roles"], pre_message_sep=model_prompt_details["pre_message_sep"], post_message_sep=model_prompt_details["post_message_sep"], messages=messages)
else:
- for message in messages:
- prompt += f"{message['content']}"
+ prompt = prompt_factory(model=model, messages=messages)
### MAP INPUT PARAMS
data = {
"inputs": prompt,
diff --git a/litellm/main.py b/litellm/main.py
index a5a9b0b38..05b0f1981 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -563,8 +563,8 @@ def completion(
logger_fn=logger_fn,
encoding=encoding,
api_key=huggingface_key,
- logging_obj=logging
-
+ logging_obj=logging,
+ custom_prompt_dict=litellm.custom_prompt_dict
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
diff --git a/litellm/tests/test_hf_prompt_templates.py b/litellm/tests/test_hf_prompt_templates.py
new file mode 100644
index 000000000..0066a9e9d
--- /dev/null
+++ b/litellm/tests/test_hf_prompt_templates.py
@@ -0,0 +1,43 @@
+# import sys, os
+# import traceback
+# from dotenv import load_dotenv
+
+# load_dotenv()
+# import os
+
+# sys.path.insert(
+# 0, os.path.abspath("../..")
+# ) # Adds the parent directory to the system path
+# import pytest
+# import litellm
+# from litellm import embedding, completion, text_completion
+
+# def logger_fn(user_model_dict):
+# return
+# print(f"user_model_dict: {user_model_dict}")
+
+# messages=[{"role": "user", "content": "Write me a function to print hello world"}]
+
+# # test if the first-party prompt templates work
+# def test_huggingface_supported_models():
+# model = "huggingface/WizardLM/WizardCoder-Python-34B-V1.0"
+# response = completion(model=model, messages=messages, max_tokens=256, api_base="https://ji16r2iys9a8rjk2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
+# print(response['choices'][0]['message']['content'])
+# return response
+
+# test_huggingface_supported_models()
+
+# # test if a custom prompt template works
+# litellm.register_prompt_template(
+# model="togethercomputer/LLaMA-2-7B-32K",
+# roles={"system":"", "assistant":"Assistant:", "user":"User:"},
+# pre_message_sep= "\n",
+# post_message_sep= "\n"
+# )
+# def test_huggingface_custom_model():
+# model = "huggingface/togethercomputer/LLaMA-2-7B-32K"
+# response = completion(model=model, messages=messages, api_base="https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
+# print(response['choices'][0]['message']['content'])
+# return response
+
+# test_huggingface_custom_model()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index d611c05ed..3f403ce70 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1326,6 +1326,27 @@ def modify_integration(integration_name, integration_params):
Supabase.supabase_table_name = integration_params["table_name"]
+# custom prompt helper function
+def register_prompt_template(model: str, roles: dict, pre_message_sep: str, post_message_sep: str):
+ """
+ Example usage:
+ ```
+ import litellm
+ litellm.register_prompt_template(
+ model="bloomz",
+        roles={"system": "<|im_start|>system", "assistant": "<|im_start|>assistant", "user": "<|im_start|>user"},
+        pre_message_sep="\n",
+        post_message_sep="<|im_end|>\n"
+ )
+ ```
+ """
+ litellm.custom_prompt_dict[model] = {
+ "roles": roles,
+ "pre_message_sep": pre_message_sep,
+ "post_message_sep": post_message_sep
+ }
+ return litellm.custom_prompt_dict
+
####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
@@ -1415,7 +1436,6 @@ def get_model_list():
f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
)
-
####### EXCEPTION MAPPING ################
def exception_type(model, original_exception, custom_llm_provider):
global user_logger_fn, liteDebuggerClient
diff --git a/pyproject.toml b/pyproject.toml
index a733fdd0b..d56f01f9c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "0.1.531"
+version = "0.1.532"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"