prompt formatting for together ai llama2 models

This commit is contained in:
Krrish Dholakia 2023-09-05 11:57:13 -07:00
parent 51b64c59f3
commit 090ec35a4d
4 changed files with 31 additions and 23 deletions

View file

@ -7,7 +7,7 @@ import time
from typing import Callable
from litellm.utils import ModelResponse
from typing import Optional
from .huggingface_model_prompt_templates.factory import prompt_factory, custom_prompt
from .prompt_templates.factory import prompt_factory, custom_prompt
class HuggingfaceError(Exception):
def __init__(self, status_code, message):

View file

@ -13,8 +13,18 @@ def llama_2_chat_pt(messages):
prompt += message["content"] + "[/INST]"
return prompt
def llama_2_pt(messages):
return " ".join(message["content"] for message in messages)
# TogetherAI Llama2 prompt template
def togetherai_llama_2_chat_pt(messages):
prompt = "[INST]\n"
for message in messages:
if message["role"] == "system":
prompt += message["content"]
elif message["role"] == "assistant":
prompt += message["content"]
elif message["role"] == "user":
prompt += message["content"]
prompt += "\n[\INST]\n\n"
return prompt
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def falcon_instruct_pt(messages):
@ -77,21 +87,17 @@ def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str,
def prompt_factory(model: str, messages: list):
model = model.lower()
if "bloom" in model:
return default_pt(messages=messages)
elif "flan-t5" in model:
return default_pt(messages=messages)
elif "meta-llama" in model:
if "meta-llama/Llama-2" in model:
if "chat" in model:
return llama_2_chat_pt(messages=messages)
else:
return default_pt(messages=messages)
elif "falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
if "instruct" in model:
return falcon_instruct_pt(messages=messages)
else:
return default_pt(messages=messages)
elif "mpt" in model:
elif "mosaicml/mpt" in model:
if "chat" in model:
return mpt_chat_pt(messages=messages)
else:
@ -101,9 +107,11 @@ def prompt_factory(model: str, messages: list):
return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
else:
return default_pt(messages=messages)
elif "wizardcoder" in model:
elif "wizardlm/wizardcoder" in model:
return wizardcoder_pt(messages=messages)
elif "phind-codellama" in model:
elif "phind/phind-codellama" in model:
return phind_codellama_pt(messages=messages)
elif "togethercomputer/llama-2" in model:
return togetherai_llama_2_chat_pt(messages=messages)
else:
return default_pt(messages=messages)

View file

@ -5,6 +5,7 @@ import requests
import time
from typing import Callable
from litellm.utils import ModelResponse
from .prompt_templates.factory import prompt_factory, custom_prompt
class TogetherAIError(Exception):
def __init__(self, status_code, message):
@ -34,21 +35,19 @@ def completion(
encoding,
api_key,
logging_obj,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
model = model
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(role_dict=model_prompt_details["roles"], pre_message_sep=model_prompt_details["pre_message_sep"], post_message_sep=model_prompt_details["post_message_sep"], messages=messages)
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"model": model,
"prompt": prompt,

View file

@ -20,7 +20,7 @@ litellm.use_client = True
# litellm.set_verbose = True
# litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])
user_message = "write me a function to print hello world in python"
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
@ -390,12 +390,13 @@ def test_completion_replicate_stability():
def test_completion_together_ai():
model_name = "togethercomputer/llama-2-70b-chat"
try:
response = completion(model=model_name, messages=messages)
response = completion(model=model_name, messages=messages, max_tokens=256, logger_fn=logger_fn)
# Add any assertions here to check the response
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_completion_together_ai()
def test_completion_sagemaker():
try: