Mirror of https://github.com/BerriAI/litellm.git
Prompt formatting for Together AI Llama-2 models
commit 090ec35a4d
parent 51b64c59f3
4 changed files with 31 additions and 23 deletions
litellm/llms/huggingface_restapi.py
@@ -7,7 +7,7 @@ import time
 from typing import Callable
 from litellm.utils import ModelResponse
 from typing import Optional
-from .huggingface_model_prompt_templates.factory import prompt_factory, custom_prompt
+from .prompt_templates.factory import prompt_factory, custom_prompt

 class HuggingfaceError(Exception):
     def __init__(self, status_code, message):
litellm/llms/prompt_templates/factory.py
@@ -13,8 +13,18 @@ def llama_2_chat_pt(messages):
         prompt += message["content"] + "[/INST]"
     return prompt

-def llama_2_pt(messages):
-    return " ".join(message["content"] for message in messages)
+# TogetherAI Llama2 prompt template
+def togetherai_llama_2_chat_pt(messages):
+    prompt = "[INST]\n"
+    for message in messages:
+        if message["role"] == "system":
+            prompt += message["content"]
+        elif message["role"] == "assistant":
+            prompt += message["content"]
+        elif message["role"] == "user":
+            prompt += message["content"]
+    prompt += "\n[\INST]\n\n"
+    return prompt

 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
 def falcon_instruct_pt(messages):
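For reference, a minimal sketch of what the new template produces. The import path is assumed from the relative import `.prompt_templates.factory` used elsewhere in this commit, and the example messages are hypothetical:

    from litellm.llms.prompt_templates.factory import togetherai_llama_2_chat_pt

    messages = [
        {"role": "system", "content": "You are a helpful assistant. "},
        {"role": "user", "content": "Write a short poem about the sky"},
    ]

    # Each message's content is appended inside a single [INST] wrapper, so per the
    # implementation above the result is:
    # "[INST]\nYou are a helpful assistant. Write a short poem about the sky\n[\INST]\n\n"
    print(togetherai_llama_2_chat_pt(messages))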
@@ -77,21 +87,17 @@ def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str,

 def prompt_factory(model: str, messages: list):
     model = model.lower()
-    if "bloom" in model:
-        return default_pt(messages=messages)
-    elif "flan-t5" in model:
-        return default_pt(messages=messages)
-    elif "meta-llama" in model:
+    if "meta-llama/Llama-2" in model:
         if "chat" in model:
             return llama_2_chat_pt(messages=messages)
         else:
             return default_pt(messages=messages)
-    elif "falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
+    elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
         if "instruct" in model:
             return falcon_instruct_pt(messages=messages)
         else:
             return default_pt(messages=messages)
-    elif "mpt" in model:
+    elif "mosaicml/mpt" in model:
         if "chat" in model:
             return mpt_chat_pt(messages=messages)
         else:
@@ -101,9 +107,11 @@ def prompt_factory(model: str, messages: list):
             return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
         else:
             return default_pt(messages=messages)
-    elif "wizardcoder" in model:
+    elif "wizardlm/wizardcoder" in model:
        return wizardcoder_pt(messages=messages)
-    elif "phind-codellama" in model:
+    elif "phind/phind-codellama" in model:
        return phind_codellama_pt(messages=messages)
+    elif "togethercomputer/llama-2" in model:
+        return togetherai_llama_2_chat_pt(messages=messages)
    else:
        return default_pt(messages=messages)
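A short sketch of the routing this gives prompt_factory: because the model string is lowercased before matching, the lowercase "togethercomputer/llama-2" prefix is what selects the new Together AI template, while unrecognised models still fall through to default_pt. The import path is assumed and the messages are hypothetical:

    from litellm.llms.prompt_templates.factory import prompt_factory

    messages = [{"role": "user", "content": "Write a short poem about the sky"}]

    # Hits the new "togethercomputer/llama-2" branch -> togetherai_llama_2_chat_pt
    together_prompt = prompt_factory(model="togethercomputer/llama-2-70b-chat", messages=messages)

    # No matching prefix -> falls through to default_pt
    fallback_prompt = prompt_factory(model="some-other-model", messages=messages)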
litellm/llms/together_ai.py
@@ -5,6 +5,7 @@ import requests
 import time
 from typing import Callable
 from litellm.utils import ModelResponse
+from .prompt_templates.factory import prompt_factory, custom_prompt

 class TogetherAIError(Exception):
     def __init__(self, status_code, message):
@@ -34,21 +35,19 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
     model = model
-    prompt = ""
-    for message in messages:
-        if "role" in message:
-            if message["role"] == "user":
-                prompt += f"{message['content']}"
-            else:
-                prompt += f"{message['content']}"
-        else:
-            prompt += f"{message['content']}"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(role_dict=model_prompt_details["roles"], pre_message_sep=model_prompt_details["pre_message_sep"], post_message_sep=model_prompt_details["post_message_sep"], messages=messages)
+    else:
+        prompt = prompt_factory(model=model, messages=messages)
     data = {
         "model": model,
         "prompt": prompt,
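Based on the keys this branch reads ("roles", "pre_message_sep", "post_message_sep"), a caller could register a per-model template roughly like the hypothetical sketch below. Only the dict shape is implied by the diff; the role strings and separators are illustrative, not part of the commit:

    # Hypothetical registration of a custom template for one Together AI model.
    custom_prompt_dict = {
        "togethercomputer/llama-2-70b-chat": {
            "roles": {"system": "<<SYS>>", "user": "[INST]", "assistant": "[/INST]"},
            "pre_message_sep": "\n",
            "post_message_sep": "\n",
        }
    }

    # Inside completion(): a model present in custom_prompt_dict is formatted with
    # custom_prompt(...); any other model goes through prompt_factory(model=model, messages=messages).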
litellm/tests/test_completion.py
@@ -20,7 +20,7 @@ litellm.use_client = True
 # litellm.set_verbose = True
 # litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])

-user_message = "write me a function to print hello world in python"
+user_message = "Write a short poem about the sky"
 messages = [{"content": user_message, "role": "user"}]

@@ -390,12 +390,13 @@ def test_completion_replicate_stability():
 def test_completion_together_ai():
     model_name = "togethercomputer/llama-2-70b-chat"
     try:
-        response = completion(model=model_name, messages=messages)
+        response = completion(model=model_name, messages=messages, max_tokens=256, logger_fn=logger_fn)
         # Add any assertions here to check the response
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+test_completion_together_ai()

 def test_completion_sagemaker():
     try: