prompt formatting for together ai llama2 models

Krrish Dholakia 2023-09-05 11:57:13 -07:00
parent 51b64c59f3
commit 090ec35a4d
4 changed files with 31 additions and 23 deletions

View file

@@ -7,7 +7,7 @@ import time
 from typing import Callable
 from litellm.utils import ModelResponse
 from typing import Optional
-from .huggingface_model_prompt_templates.factory import prompt_factory, custom_prompt
+from .prompt_templates.factory import prompt_factory, custom_prompt
 
 class HuggingfaceError(Exception):
     def __init__(self, status_code, message):

View file

@@ -13,8 +13,18 @@ def llama_2_chat_pt(messages):
             prompt += message["content"] + "[/INST]"
     return prompt
 
-def llama_2_pt(messages):
-    return " ".join(message["content"] for message in messages)
+# TogetherAI Llama2 prompt template
+def togetherai_llama_2_chat_pt(messages):
+    prompt = "[INST]\n"
+    for message in messages:
+        if message["role"] == "system":
+            prompt += message["content"]
+        elif message["role"] == "assistant":
+            prompt += message["content"]
+        elif message["role"] == "user":
+            prompt += message["content"]
+    prompt += "\n[/INST]\n\n"
+    return prompt
 
 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
 def falcon_instruct_pt(messages):
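
Note (illustrative, not part of the commit): a minimal sketch of what the new Together AI template produces. The sample conversation below is made up; the logic simply mirrors the togetherai_llama_2_chat_pt function added above.

# Illustrative sketch: mirrors the togetherai_llama_2_chat_pt logic from the hunk above.
# The example messages are hypothetical.
messages = [
    {"role": "system", "content": "You are a helpful assistant. "},
    {"role": "user", "content": "Write a short poem about the sky."},
]

prompt = "[INST]\n"
for message in messages:
    # system, assistant, and user content are all concatenated the same way
    prompt += message["content"]
prompt += "\n[/INST]\n\n"

# prompt == "[INST]\nYou are a helpful assistant. Write a short poem about the sky.\n[/INST]\n\n"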
@@ -77,21 +87,17 @@ def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str,
 
 def prompt_factory(model: str, messages: list):
     model = model.lower()
-    if "bloom" in model:
-        return default_pt(messages=messages)
-    elif "flan-t5" in model:
-        return default_pt(messages=messages)
-    elif "meta-llama" in model:
+    if "meta-llama/llama-2" in model:
         if "chat" in model:
             return llama_2_chat_pt(messages=messages)
         else:
             return default_pt(messages=messages)
-    elif "falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
+    elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
         if "instruct" in model:
             return falcon_instruct_pt(messages=messages)
         else:
             return default_pt(messages=messages)
-    elif "mpt" in model:
+    elif "mosaicml/mpt" in model:
         if "chat" in model:
             return mpt_chat_pt(messages=messages)
         else:
@@ -101,9 +107,11 @@ def prompt_factory(model: str, messages: list):
             return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
         else:
             return default_pt(messages=messages)
-    elif "wizardcoder" in model:
+    elif "wizardlm/wizardcoder" in model:
         return wizardcoder_pt(messages=messages)
-    elif "phind-codellama" in model:
+    elif "phind/phind-codellama" in model:
         return phind_codellama_pt(messages=messages)
+    elif "togethercomputer/llama-2" in model:
+        return togetherai_llama_2_chat_pt(messages=messages)
     else:
         return default_pt(messages=messages)
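
Note (illustrative, not part of the commit): a usage sketch of the updated routing. The absolute import path below is an assumption; the diff only shows the relative module .prompt_templates.factory.

# Assumed import path, for illustration only.
from litellm.llms.prompt_templates.factory import prompt_factory

messages = [{"role": "user", "content": "Write a short poem about the sky"}]

# A Together AI model id now hits the new "togethercomputer/llama-2" branch
# and is formatted with togetherai_llama_2_chat_pt instead of default_pt.
prompt = prompt_factory(model="togethercomputer/llama-2-70b-chat", messages=messages)
print(prompt)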

View file

@@ -5,6 +5,7 @@ import requests
 import time
 from typing import Callable
 from litellm.utils import ModelResponse
+from .prompt_templates.factory import prompt_factory, custom_prompt
 
 class TogetherAIError(Exception):
     def __init__(self, status_code, message):
@@ -34,21 +35,19 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
     model = model
-    prompt = ""
-    for message in messages:
-        if "role" in message:
-            if message["role"] == "user":
-                prompt += f"{message['content']}"
-            else:
-                prompt += f"{message['content']}"
-        else:
-            prompt += f"{message['content']}"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(role_dict=model_prompt_details["roles"], pre_message_sep=model_prompt_details["pre_message_sep"], post_message_sep=model_prompt_details["post_message_sep"], messages=messages)
+    else:
+        prompt = prompt_factory(model=model, messages=messages)
     data = {
         "model": model,
         "prompt": prompt,

View file

@@ -20,7 +20,7 @@ litellm.use_client = True
 # litellm.set_verbose = True
 # litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])
 
-user_message = "write me a function to print hello world in python"
+user_message = "Write a short poem about the sky"
 messages = [{"content": user_message, "role": "user"}]
@@ -390,12 +390,13 @@ def test_completion_replicate_stability():
 
 def test_completion_together_ai():
     model_name = "togethercomputer/llama-2-70b-chat"
     try:
-        response = completion(model=model_name, messages=messages)
+        response = completion(model=model_name, messages=messages, max_tokens=256, logger_fn=logger_fn)
         # Add any assertions here to check the response
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+test_completion_together_ai()
 
 def test_completion_sagemaker():
     try:
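
Note (illustrative, not part of the commit): the updated test also passes logger_fn, which is defined earlier in the test file and not shown in this diff. A minimal hypothetical callback of that shape could look like:

# Hypothetical logging callback; the parameter name and payload shape are assumptions.
# litellm invokes logger_fn with details about the model call, so printing is enough here.
def logger_fn(model_call_details):
    print(f"model call details: {model_call_details}")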