diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 205bf56fa..de5e6615c 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -3,27 +3,22 @@ def default_pt(messages):
# Llama2 prompt template
def llama_2_chat_pt(messages):
- prompt = ""
- for message in messages:
- if message["role"] == "system":
- prompt += "[INST] <>" + message["content"]
- elif message["role"] == "assistant":
- prompt += message["content"] + "[INST]"
- elif message["role"] == "user":
- prompt += message["content"] + "[/INST]"
- return prompt
-
-# TogetherAI Llama2 prompt template
-def togetherai_llama_2_chat_pt(messages):
- prompt = "[INST]\n"
- for message in messages:
- if message["role"] == "system":
- prompt += message["content"]
- elif message["role"] == "assistant":
- prompt += message["content"]
- elif message["role"] == "user":
- prompt += message["content"]
- prompt += "\n[\INST]\n\n"
+    prompt = custom_prompt(
+        role_dict={
+            "system": {
+                "pre_message": "[INST] <<SYS>>\n",
+                "post_message": "\n<</SYS>>\n [/INST]\n"
+            },
+            "user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
+                "pre_message": "[INST] ",
+                "post_message": " [/INST]\n"
+            },
+            "assistant": {
+                "post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
+            }
+        },
+        messages=messages
+    )
    return prompt
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
@@ -84,15 +79,15 @@ def phind_codellama_pt(messages):
    return prompt
# Custom prompt template
-def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str, messages: list):
-    prompt = ""
+def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str=""):
+    prompt = initial_prompt_value
    for message in messages:
-        if message["role"] == "system":
-            prompt += f"{role_dict['system']}{pre_message_sep}" + message["content"] + post_message_sep
-        elif message["role"] == "user":
-            prompt += f"{role_dict['user']}{pre_message_sep}" + message["content"] + post_message_sep
-        elif message["role"] == "assistant":
-            prompt += f"{role_dict['assistant']}{pre_message_sep}" + message["content"] + post_message_sep
+        role = message["role"]
+        pre_message_str = role_dict[role]["pre_message"] if role in role_dict and "pre_message" in role_dict[role] else ""
+        post_message_str = role_dict[role]["post_message"] if role in role_dict and "post_message" in role_dict[role] else ""
+        prompt += pre_message_str + message["content"] + post_message_str
+
+    prompt += final_prompt_value
    return prompt
def prompt_factory(model: str, messages: list):
@@ -100,30 +95,21 @@ def prompt_factory(model: str, messages: list):
if "meta-llama/Llama-2" in model:
if "chat" in model:
return llama_2_chat_pt(messages=messages)
- else:
- return default_pt(messages=messages)
elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
if model == "tiiuae/falcon-180B-chat":
return falcon_chat_pt(messages=messages)
elif "instruct" in model:
return falcon_instruct_pt(messages=messages)
- else:
- return default_pt(messages=messages)
elif "mosaicml/mpt" in model:
if "chat" in model:
return mpt_chat_pt(messages=messages)
- else:
- return default_pt(messages=messages)
elif "codellama/codellama" in model:
if "instruct" in model:
return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
- else:
- return default_pt(messages=messages)
elif "wizardlm/wizardcoder" in model:
return wizardcoder_pt(messages=messages)
elif "phind/phind-codellama" in model:
return phind_codellama_pt(messages=messages)
- elif "togethercomputer/llama-2" in model and "instruct" in model:
- return togetherai_llama_2_chat_pt(messages=messages)
- else:
- return default_pt(messages=messages)
\ No newline at end of file
+ elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
+ return llama_2_chat_pt(messages=messages)
+ return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
\ No newline at end of file
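
For reviewers, a minimal sketch of what the refactored llama_2_chat_pt above is expected to produce; the sample messages are made up rather than taken from the repo's tests, and the <<SYS>> delimiters are assumed from the standard Llama-2 chat format.

# Illustrative only: sample messages are not from the test suite.
from litellm.llms.prompt_templates.factory import llama_2_chat_pt

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
print(llama_2_chat_pt(messages=messages))
# Expected shape, per the role_dict in the hunk above:
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#  [/INST]
# [INST] What is the capital of France? [/INST]
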
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index e076900a2..c4b97bf2d 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -435,7 +435,7 @@ def test_completion_together_ai():
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
-
+test_completion_together_ai()
# def test_customprompt_together_ai():
#     try:
#         litellm.register_prompt_template(
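
The re-enabled test_completion_together_ai above exercises the new prompt_factory routing: a togethercomputer/llama-2 model id containing "chat" or "instruct" now flows through llama_2_chat_pt. A minimal sketch of such a call follows, assuming a TogetherAI key is set; the model id, env var, and message are placeholders, not the test's actual values.

# Sketch only: model id, key handling, and message are placeholders.
import os
import litellm

os.environ["TOGETHERAI_API_KEY"] = "..."  # assumed auth env var

response = litellm.completion(
    model="togethercomputer/llama-2-70b-chat",  # hits the new chat/instruct branch in prompt_factory
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(response)
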
diff --git a/pyproject.toml b/pyproject.toml
index 8e4c6bb49..1dbba1893 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "0.1.544"
+version = "0.1.545"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"