diff --git a/litellm/__init__.py b/litellm/__init__.py
index a4ed950a8..f3b51947c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -47,7 +47,7 @@ def get_model_cost_map():
         print("Error occurred:", e)
         return None
 model_cost = get_model_cost_map()
-
+custom_prompt_dict = {}
 ####### THREAD-SPECIFIC DATA ###################
 class MyLocal(threading.local):
     def __init__(self):
@@ -224,7 +224,8 @@ from .utils import (
     acreate,
     get_model_list,
     completion_with_split_tests,
-    get_max_tokens
+    get_max_tokens,
+    register_prompt_template
 )
 from .main import *  # type: ignore
 from .integrations import *
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index e68c4da03..684c6587e 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 4e3285a90..809ab0689 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 309ff5b2e..7a8c10db6 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/huggingface_model_prompt_templates/factory.py b/litellm/llms/huggingface_model_prompt_templates/factory.py
new file mode 100644
index 000000000..b2c5c1425
--- /dev/null
+++ b/litellm/llms/huggingface_model_prompt_templates/factory.py
@@ -0,0 +1,109 @@
+def default_pt(messages):
+    return " ".join(message["content"] for message in messages)
+
+# Llama2 prompt template
+def llama_2_chat_pt(messages):
+    prompt = ""
+    for message in messages:
+        if message["role"] == "system":
+            prompt += "[INST] <<SYS>>" + message["content"]
+        elif message["role"] == "assistant":
+            prompt += message["content"] + "[INST]"
+        elif message["role"] == "user":
+            prompt += message["content"] + "[/INST]"
+    return prompt
+
+def llama_2_pt(messages):
+    return " ".join(message["content"] for message in messages)
+
+# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
+def falcon_instruct_pt(messages):
+    prompt = ""
+    for message in messages:
+        if message["role"] == "system":
+            prompt += message["content"]
+        else:
+            prompt += message["role"] + ":" + message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
+            prompt += "\n\n"
+    return prompt
+
+# MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
+def mpt_chat_pt(messages):
+    prompt = ""
+    for message in messages:
+        if message["role"] == "system":
+            prompt += "<|im_start|>system" + message["content"] + "<|im_end|>" + "\n"
+        elif message["role"] == "assistant":
+            prompt += "<|im_start|>assistant" + message["content"] + "<|im_end|>" + "\n"
+        elif message["role"] == "user":
+            prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n"
+    return prompt
+
+# WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format
+def wizardcoder_pt(messages):
+    prompt = ""
+    for message in messages:
+        if message["role"] == "system":
+            prompt += message["content"] + "\n\n"
+        elif message["role"] == "user":  # map to 'Instruction'
+            prompt += "### Instruction:\n" + message["content"] + "\n\n"
+        elif message["role"] == "assistant":  # map to 'Response'
+            prompt += "### Response:\n" + message["content"] + "\n\n"
+    return prompt
+
+# Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model
+def phind_codellama_pt(messages):
+    prompt = ""
+    for message in messages:
+        if message["role"] == "system":
+            prompt += "### System Prompt\n" + message["content"] + "\n\n"
+        elif message["role"] == "user":
+            prompt += "### User Message\n" + message["content"] + "\n\n"
+        elif message["role"] == "assistant":
+            prompt += "### Assistant\n" + message["content"] + "\n\n"
+    return prompt
+
+# Custom prompt template
+def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str, messages: list):
+    prompt = ""
+    for message in messages:
+        if message["role"] == "system":
+            prompt += f"{role_dict['system']}{pre_message_sep}" + message["content"] + post_message_sep
+        elif message["role"] == "user":
+            prompt += f"{role_dict['user']}{pre_message_sep}" + message["content"] + post_message_sep
+        elif message["role"] == "assistant":
+            prompt += f"{role_dict['assistant']}{pre_message_sep}" + message["content"] + post_message_sep
+    return prompt
+
+def prompt_factory(model: str, messages: list):
+    model = model.lower()
+    if "bloom" in model:
+        return default_pt(messages=messages)
+    elif "flan-t5" in model:
+        return default_pt(messages=messages)
+    elif "meta-llama" in model:
+        if "chat" in model:
+            return llama_2_chat_pt(messages=messages)
+        else:
+            return default_pt(messages=messages)
+    elif "falcon" in model:  # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
+        if "instruct" in model:
+            return falcon_instruct_pt(messages=messages)
+        else:
+            return default_pt(messages=messages)
+    elif "mpt" in model:
+        if "chat" in model:
+            return mpt_chat_pt(messages=messages)
+        else:
+            return default_pt(messages=messages)
+    elif "codellama/codellama" in model:
+        if "instruct" in model:
+            return llama_2_chat_pt(messages=messages)  # https://huggingface.co/blog/codellama#conversational-instructions
+        else:
+            return default_pt(messages=messages)
+    elif "wizardcoder" in model:
+        return wizardcoder_pt(messages=messages)
+    elif "phind-codellama" in model:
+        return phind_codellama_pt(messages=messages)
+    else:
+        return default_pt(messages=messages)
\ No newline at end of file
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index dcd7c3efd..51b61b0d8 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -7,6 +7,7 @@ import time
 from typing import Callable
 from litellm.utils import ModelResponse
 from typing import Optional
+from .huggingface_model_prompt_templates.factory import prompt_factory, custom_prompt
 
 class HuggingfaceError(Exception):
     def __init__(self, status_code, message):
@@ -33,6 +34,7 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -47,21 +49,12 @@ def completion(
         completion_url = os.getenv("HF_API_BASE", "")
     else:
         completion_url = f"https://api-inference.huggingface.co/models/{model}"
-    prompt = ""
-    if (
-        "meta-llama" in model and "chat" in model
-    ):  # use the required special tokens for meta-llama - https://huggingface.co/blog/llama2#how-to-prompt-llama-2
-        prompt = ""
-        for message in messages:
-            if message["role"] == "system":
-                prompt += "[INST] <<SYS>>" + message["content"]
-            elif message["role"] == "assistant":
-                prompt += message["content"] + "[INST]"
-            elif message["role"] == "user":
-                prompt += message["content"] + "[/INST]"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(role_dict=model_prompt_details["roles"], pre_message_sep=model_prompt_details["pre_message_sep"], post_message_sep=model_prompt_details["post_message_sep"], messages=messages)
     else:
-        for message in messages:
-            prompt += f"{message['content']}"
+        prompt = prompt_factory(model=model, messages=messages)
     ### MAP INPUT PARAMS
     data = {
         "inputs": prompt,
diff --git a/litellm/main.py b/litellm/main.py
index a5a9b0b38..05b0f1981 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -563,8 +563,8 @@ def completion(
             logger_fn=logger_fn,
             encoding=encoding,
             api_key=huggingface_key,
-            logging_obj=logging
-
+            logging_obj=logging,
+            custom_prompt_dict=litellm.custom_prompt_dict
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
diff --git a/litellm/tests/test_hf_prompt_templates.py b/litellm/tests/test_hf_prompt_templates.py
new file mode 100644
index 000000000..0066a9e9d
--- /dev/null
+++ b/litellm/tests/test_hf_prompt_templates.py
@@ -0,0 +1,43 @@
+# import sys, os
+# import traceback
+# from dotenv import load_dotenv
+
+# load_dotenv()
+# import os
+
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# )  # Adds the parent directory to the system path
+# import pytest
+# import litellm
+# from litellm import embedding, completion, text_completion
+
+# def logger_fn(user_model_dict):
+#     return
+#     print(f"user_model_dict: {user_model_dict}")
+
+# messages = [{"role": "user", "content": "Write me a function to print hello world"}]
+
+# # test if the first-party prompt templates work
+# def test_huggingface_supported_models():
+#     model = "huggingface/WizardLM/WizardCoder-Python-34B-V1.0"
+#     response = completion(model=model, messages=messages, max_tokens=256, api_base="https://ji16r2iys9a8rjk2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
+#     print(response['choices'][0]['message']['content'])
+#     return response
+
+# test_huggingface_supported_models()
+
+# # test if a custom prompt template works
+# litellm.register_prompt_template(
+#     model="togethercomputer/LLaMA-2-7B-32K",
+#     roles={"system": "", "assistant": "Assistant:", "user": "User:"},
+#     pre_message_sep="\n",
+#     post_message_sep="\n"
+# )
+# def test_huggingface_custom_model():
+#     model = "huggingface/togethercomputer/LLaMA-2-7B-32K"
+#     response = completion(model=model, messages=messages, api_base="https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
+#     print(response['choices'][0]['message']['content'])
+#     return response
+
+# test_huggingface_custom_model()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index d611c05ed..3f403ce70 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1326,6 +1326,27 @@ def modify_integration(integration_name, integration_params):
             Supabase.supabase_table_name = integration_params["table_name"]
 
 
+# custom prompt helper function
+def register_prompt_template(model: str, roles: dict, pre_message_sep: str, post_message_sep: str):
+    """
+    Example usage:
+    ```
+    import litellm
+    litellm.register_prompt_template(
+        model="bloomz",
+        roles={"system": "<|im_start|>system", "assistant": "<|im_start|>assistant", "user": "<|im_start|>user"},
+        pre_message_sep="\n",
+        post_message_sep="<|im_end|>\n"
+    )
+    ```
+    """
+    litellm.custom_prompt_dict[model] = {
+        "roles": roles,
+        "pre_message_sep": pre_message_sep,
+        "post_message_sep": post_message_sep
+    }
+    return litellm.custom_prompt_dict
+
 ####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
 
 
@@ -1415,7 +1436,6 @@ def get_model_list():
             f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
         )
 
-
 ####### EXCEPTION MAPPING ################
 def exception_type(model, original_exception, custom_llm_provider):
     global user_logger_fn, liteDebuggerClient
diff --git a/pyproject.toml b/pyproject.toml
index a733fdd0b..d56f01f9c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.531"
+version = "0.1.532"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
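For reviewers, a minimal usage sketch of the API introduced in this change, assembled from the `register_prompt_template` docstring and the commented-out test above. The endpoint URL is a placeholder, and the role strings follow the LLaMA-2-7B-32K example from the test; this is an illustration under those assumptions, not part of the diff itself.

```python
import litellm
from litellm import completion

# Register a custom prompt template for a model that has no first-party
# template in prompt_factory(); stored in litellm.custom_prompt_dict.
litellm.register_prompt_template(
    model="togethercomputer/LLaMA-2-7B-32K",
    roles={"system": "", "assistant": "Assistant:", "user": "User:"},
    pre_message_sep="\n",
    post_message_sep="\n",
)

messages = [{"role": "user", "content": "Write me a function to print hello world"}]

# The huggingface/ prefix routes to huggingface_restapi.completion(), which now
# checks custom_prompt_dict before falling back to prompt_factory().
response = completion(
    model="huggingface/togethercomputer/LLaMA-2-7B-32K",
    messages=messages,
    api_base="https://<your-endpoint>.endpoints.huggingface.cloud",  # placeholder endpoint
)
print(response["choices"][0]["message"]["content"])
```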