mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
prompt formatting for together ai llama2 models
This commit is contained in:
parent
51b64c59f3
commit
090ec35a4d
4 changed files with 31 additions and 23 deletions
|
@ -7,7 +7,7 @@ import time
|
|||
from typing import Callable
|
||||
from litellm.utils import ModelResponse
|
||||
from typing import Optional
|
||||
from .huggingface_model_prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
class HuggingfaceError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
|
|
|
@ -13,8 +13,18 @@ def llama_2_chat_pt(messages):
|
|||
prompt += message["content"] + "[/INST]"
|
||||
return prompt
|
||||
|
||||
def llama_2_pt(messages):
|
||||
return " ".join(message["content"] for message in messages)
|
||||
# TogetherAI Llama2 prompt template
|
||||
def togetherai_llama_2_chat_pt(messages):
|
||||
prompt = "[INST]\n"
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
prompt += message["content"]
|
||||
elif message["role"] == "assistant":
|
||||
prompt += message["content"]
|
||||
elif message["role"] == "user":
|
||||
prompt += message["content"]
|
||||
prompt += "\n[\INST]\n\n"
|
||||
return prompt
|
||||
|
||||
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
|
||||
def falcon_instruct_pt(messages):
|
||||
|
@ -77,21 +87,17 @@ def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str,
|
|||
|
||||
def prompt_factory(model: str, messages: list):
|
||||
model = model.lower()
|
||||
if "bloom" in model:
|
||||
return default_pt(messages=messages)
|
||||
elif "flan-t5" in model:
|
||||
return default_pt(messages=messages)
|
||||
elif "meta-llama" in model:
|
||||
if "meta-llama/Llama-2" in model:
|
||||
if "chat" in model:
|
||||
return llama_2_chat_pt(messages=messages)
|
||||
else:
|
||||
return default_pt(messages=messages)
|
||||
elif "falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
|
||||
elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
|
||||
if "instruct" in model:
|
||||
return falcon_instruct_pt(messages=messages)
|
||||
else:
|
||||
return default_pt(messages=messages)
|
||||
elif "mpt" in model:
|
||||
elif "mosaicml/mpt" in model:
|
||||
if "chat" in model:
|
||||
return mpt_chat_pt(messages=messages)
|
||||
else:
|
||||
|
@ -101,9 +107,11 @@ def prompt_factory(model: str, messages: list):
|
|||
return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
|
||||
else:
|
||||
return default_pt(messages=messages)
|
||||
elif "wizardcoder" in model:
|
||||
elif "wizardlm/wizardcoder" in model:
|
||||
return wizardcoder_pt(messages=messages)
|
||||
elif "phind-codellama" in model:
|
||||
elif "phind/phind-codellama" in model:
|
||||
return phind_codellama_pt(messages=messages)
|
||||
elif "togethercomputer/llama-2" in model:
|
||||
return togetherai_llama_2_chat_pt(messages=messages)
|
||||
else:
|
||||
return default_pt(messages=messages)
|
|
@ -5,6 +5,7 @@ import requests
|
|||
import time
|
||||
from typing import Callable
|
||||
from litellm.utils import ModelResponse
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
class TogetherAIError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
|
@ -34,21 +35,19 @@ def completion(
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
custom_prompt_dict={},
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
headers = validate_environment(api_key)
|
||||
model = model
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
if "role" in message:
|
||||
if message["role"] == "user":
|
||||
prompt += f"{message['content']}"
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(role_dict=model_prompt_details["roles"], pre_message_sep=model_prompt_details["pre_message_sep"], post_message_sep=model_prompt_details["post_message_sep"], messages=messages)
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
|
|
|
@ -20,7 +20,7 @@ litellm.use_client = True
|
|||
# litellm.set_verbose = True
|
||||
# litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])
|
||||
|
||||
user_message = "write me a function to print hello world in python"
|
||||
user_message = "Write a short poem about the sky"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
|
||||
|
||||
|
@ -390,12 +390,13 @@ def test_completion_replicate_stability():
|
|||
def test_completion_together_ai():
|
||||
model_name = "togethercomputer/llama-2-70b-chat"
|
||||
try:
|
||||
response = completion(model=model_name, messages=messages)
|
||||
response = completion(model=model_name, messages=messages, max_tokens=256, logger_fn=logger_fn)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
test_completion_together_ai()
|
||||
|
||||
def test_completion_sagemaker():
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue