forked from phoenix/litellm-mirror
update custom prompt template function
parent 8440791e04
commit 0ace48d719
4 changed files with 29 additions and 43 deletions
Binary file not shown.
@@ -3,27 +3,22 @@ def default_pt(messages):
 
 # Llama2 prompt template
 def llama_2_chat_pt(messages):
-    prompt = "<s>"
-    for message in messages:
-        if message["role"] == "system":
-            prompt += "[INST] <<SYS>>" + message["content"]
-        elif message["role"] == "assistant":
-            prompt += message["content"] + "</s><s>[INST]"
-        elif message["role"] == "user":
-            prompt += message["content"] + "[/INST]"
-    return prompt
-
-# TogetherAI Llama2 prompt template
-def togetherai_llama_2_chat_pt(messages):
-    prompt = "[INST]\n"
-    for message in messages:
-        if message["role"] == "system":
-            prompt += message["content"]
-        elif message["role"] == "assistant":
-            prompt += message["content"]
-        elif message["role"] == "user":
-            prompt += message["content"]
-    prompt += "\n[\INST]\n\n"
+    prompt = custom_prompt(
+        role_dict={
+            "system": {
+                "pre_message": "[INST] <<SYS>>\n",
+                "post_message": "\n<</SYS>>\n [/INST]\n"
+            },
+            "user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
+                "pre_message": "[INST] ",
+                "post_message": " [/INST]\n"
+            },
+            "assistant": {
+                "post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
+            }
+        },
+        messages=messages
+    )
     return prompt
 
 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
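For reference, the rewritten llama_2_chat_pt now delegates all formatting to custom_prompt, so a short conversation renders as shown below. This is a sketch; the import path is an assumption added for illustration and is not shown in this commit.

from litellm.llms.prompt_templates.factory import llama_2_chat_pt  # assumed import path

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hi"},
]

# Each message is wrapped in its role's pre_message/post_message markers, giving:
# [INST] <<SYS>>
# You are a helpful assistant
# <</SYS>>
#  [/INST]
# [INST] Hi [/INST]
print(llama_2_chat_pt(messages=messages))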
@@ -84,15 +79,15 @@ def phind_codellama_pt(messages):
     return prompt
 
 # Custom prompt template
-def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str, messages: list):
-    prompt = ""
+def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str=""):
+    prompt = initial_prompt_value
     for message in messages:
-        if message["role"] == "system":
-            prompt += f"{role_dict['system']}{pre_message_sep}" + message["content"] + post_message_sep
-        elif message["role"] == "user":
-            prompt += f"{role_dict['user']}{pre_message_sep}" + message["content"] + post_message_sep
-        elif message["role"] == "assistant":
-            prompt += f"{role_dict['assistant']}{pre_message_sep}" + message["content"] + post_message_sep
+        role = message["role"]
+        pre_message_str = role_dict[role]["pre_message"] if role in role_dict and "pre_message" in role_dict[role] else ""
+        post_message_str = role_dict[role]["post_message"] if role in role_dict and "post_message" in role_dict[role] else ""
+        prompt += pre_message_str + message["content"] + post_message_str
+
+    prompt += final_prompt_value
     return prompt
 
 def prompt_factory(model: str, messages: list):
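For context, a minimal sketch of the reworked custom_prompt in use. The role markers and prompt values below are made up for illustration; the point is the new initial_prompt_value/final_prompt_value parameters, and that roles (or pre_message/post_message keys) missing from role_dict simply contribute empty strings.

prompt = custom_prompt(
    role_dict={
        "system": {"pre_message": "### System:\n", "post_message": "\n"},
        "user": {"pre_message": "### User:\n", "post_message": "\n"},
        # "assistant" omitted on purpose: it falls back to empty pre/post strings
    },
    messages=[
        {"role": "system", "content": "Be concise."},
        {"role": "user", "content": "What is 2+2?"},
    ],
    initial_prompt_value="You are chatting with an AI.\n",
    final_prompt_value="### Assistant:\n",
)
# prompt == "You are chatting with an AI.\n### System:\nBe concise.\n"
#           "### User:\nWhat is 2+2?\n### Assistant:\n"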
@@ -100,30 +95,21 @@ def prompt_factory(model: str, messages: list):
     if "meta-llama/Llama-2" in model:
         if "chat" in model:
             return llama_2_chat_pt(messages=messages)
-        else:
-            return default_pt(messages=messages)
     elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
         if model == "tiiuae/falcon-180B-chat":
             return falcon_chat_pt(messages=messages)
         elif "instruct" in model:
             return falcon_instruct_pt(messages=messages)
-        else:
-            return default_pt(messages=messages)
     elif "mosaicml/mpt" in model:
         if "chat" in model:
             return mpt_chat_pt(messages=messages)
-        else:
-            return default_pt(messages=messages)
     elif "codellama/codellama" in model:
         if "instruct" in model:
             return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
-        else:
-            return default_pt(messages=messages)
     elif "wizardlm/wizardcoder" in model:
         return wizardcoder_pt(messages=messages)
     elif "phind/phind-codellama" in model:
         return phind_codellama_pt(messages=messages)
-    elif "togethercomputer/llama-2" in model and "instruct" in model:
-        return togetherai_llama_2_chat_pt(messages=messages)
-    else:
-        return default_pt(messages=messages)
+    elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
+        return llama_2_chat_pt(messages=messages)
+    return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
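The practical effect of this hunk: the per-branch else/default_pt returns are removed, so any model that matches a family check but not a chat/instruct variant now falls through to the single default_pt call at the end, and TogetherAI Llama-2 chat/instruct models reuse the standard llama_2_chat_pt template. A few illustrative calls (model names are examples only, not an exhaustive list):

prompt_factory("meta-llama/Llama-2-7b-chat", messages)          # -> llama_2_chat_pt
prompt_factory("meta-llama/Llama-2-7b", messages)               # no "chat": falls through to default_pt
prompt_factory("togethercomputer/llama-2-70b-chat", messages)   # -> llama_2_chat_pt (previously togetherai_llama_2_chat_pt)
prompt_factory("bigscience/bloom", messages)                    # unmatched models also end up in default_pt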
@@ -435,7 +435,7 @@ def test_completion_together_ai():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-
+test_completion_together_ai()
 # def test_customprompt_together_ai():
 #     try:
 #         litellm.register_prompt_template(
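A sketch of how the commented-out test_customprompt_together_ai could be filled in. The keyword arguments passed to register_prompt_template are assumed to mirror the role_dict/initial_prompt_value/final_prompt_value shape that custom_prompt now uses; the model name and exact parameter names are illustrative, not a confirmed API.

def test_customprompt_together_ai():
    try:
        # roles/final_prompt_value kwargs are assumptions mirroring custom_prompt above
        litellm.register_prompt_template(
            model="OpenAssistant/llama2-70b-oasst-sft-v10",  # example model, not from this commit
            roles={
                "system": {"pre_message": "<|im_start|>system\n", "post_message": "<|im_end|>\n"},
                "user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>\n"},
                "assistant": {"pre_message": "<|im_start|>assistant\n", "post_message": "<|im_end|>\n"},
            },
            final_prompt_value="<|im_start|>assistant\n",
        )
        messages = [{"role": "user", "content": "Hey, how's it going?"}]
        response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages)
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")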
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.544"
+version = "0.1.545"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"