From 0ace48d719597f101c06cffcaf0eb132da62e635 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 6 Sep 2023 13:14:33 -0700
Subject: [PATCH] update custom prompt template function

---
 litellm/__pycache__/__init__.cpython-311.pyc |  Bin 7809 -> 7894 bytes
 litellm/llms/prompt_templates/factory.py     |   68 ++++++++-----------
 litellm/tests/test_completion.py             |    2 +-
 pyproject.toml                               |    2 +-
 4 files changed, 29 insertions(+), 43 deletions(-)

diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index b24f97219209f5befd68b3d6b6c954d92c717740..441b62dbcca4a6770810e99d09c4367fa5c69332 100644
GIT binary patch
(base85 delta payload for the compiled bytecode cache omitted: delta 158, delta 73)

diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ ... @@ def llama_2_chat_pt(messages):
-    prompt = "<s>[INST] <<SYS>>\n"
-    for message in messages:
-        if message["role"] == "system":
-            prompt += message["content"] + "\n<</SYS>>\n [/INST]\n"
-        elif message["role"] == "assistant":
-            prompt += message["content"] + "</s><s>[INST]"
-        elif message["role"] == "user":
-            prompt += message["content"] + "[/INST]"
-    return prompt
-
-# TogetherAI Llama2 prompt template
-def togetherai_llama_2_chat_pt(messages):
-    prompt = "[INST]\n"
-    for message in messages:
-        if message["role"] == "system":
-            prompt += message["content"]
-        elif message["role"] == "assistant":
-            prompt += message["content"]
-        elif message["role"] == "user":
-            prompt += message["content"]
-    prompt += "\n[\INST]\n\n"
+    prompt = custom_prompt(
+        role_dict={
+            "system": {
+                "pre_message": "[INST] <<SYS>>\n",
+                "post_message": "\n<</SYS>>\n [/INST]\n"
+            },
+            "user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
+                "pre_message": "[INST] ",
+                "post_message": " [/INST]\n"
+            },
+            "assistant": {
+                "post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
+            }
+        },
+        messages=messages
+    )
     return prompt
 
 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
@@ -84,15 +79,15 @@ def phind_codellama_pt(messages):
     return prompt
 
 # Custom prompt template
-def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str, messages: list):
-    prompt = ""
+def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str=""):
+    prompt = initial_prompt_value
     for message in messages:
-        if message["role"] == "system":
-            prompt += f"{role_dict['system']}{pre_message_sep}" + message["content"] + post_message_sep
-        elif message["role"] == "user":
-            prompt += f"{role_dict['user']}{pre_message_sep}" + message["content"] + post_message_sep
-        elif message["role"] == "assistant":
-            prompt += f"{role_dict['assistant']}{pre_message_sep}" + message["content"] + post_message_sep
+        role = message["role"]
+        pre_message_str = role_dict[role]["pre_message"] if role in role_dict and "pre_message" in role_dict[role] else ""
+        post_message_str = role_dict[role]["post_message"] if role in role_dict and "post_message" in role_dict[role] else ""
+        prompt += pre_message_str + message["content"] + post_message_str
+
+    prompt += final_prompt_value
     return prompt
 
 def prompt_factory(model: str, messages: list):
@@ -100,30 +95,21 @@ def prompt_factory(model: str, messages: list):
     if "meta-llama/Llama-2" in model:
         if "chat" in model:
             return llama_2_chat_pt(messages=messages)
-        else:
-            return default_pt(messages=messages)
     elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
         if model == "tiiuae/falcon-180B-chat":
             return falcon_chat_pt(messages=messages)
         elif "instruct" in model:
             return falcon_instruct_pt(messages=messages)
-        else:
-            return default_pt(messages=messages)
     elif "mosaicml/mpt" in model:
         if "chat" in model:
             return mpt_chat_pt(messages=messages)
-        else:
-            return default_pt(messages=messages)
     elif "codellama/codellama" in model:
         if "instruct" in model:
             return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
-        else:
-            return default_pt(messages=messages)
     elif "wizardlm/wizardcoder" in model:
         return wizardcoder_pt(messages=messages)
     elif "phind/phind-codellama" in model:
         return phind_codellama_pt(messages=messages)
-    elif "togethercomputer/llama-2" in model and "instruct" in model:
-        return togetherai_llama_2_chat_pt(messages=messages)
-    else:
-        return default_pt(messages=messages)
\ No newline at end of file
+    elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
+        return llama_2_chat_pt(messages=messages)
+    return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
\ No newline at end of file
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index e076900a2..c4b97bf2d 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -435,7 +435,7 @@ def test_completion_together_ai():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-
+test_completion_together_ai()
 # def test_customprompt_together_ai():
 #     try:
 #         litellm.register_prompt_template(
diff --git a/pyproject.toml b/pyproject.toml
index 8e4c6bb49..1dbba1893 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.544"
+version = "0.1.545"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
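
Note for reviewers: below is a minimal standalone sketch of how the reworked `custom_prompt` helper assembles a prompt from the new per-role `pre_message`/`post_message` dictionaries. The function body mirrors the version added in this patch, the `role_dict` is the Llama-2 chat mapping now used by `llama_2_chat_pt`, and the sample conversation is hypothetical, purely for illustration.

```python
# Sketch of the reworked custom_prompt helper (body mirrors this patch).
def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str = "", final_prompt_value: str = ""):
    prompt = initial_prompt_value
    for message in messages:
        role = message["role"]
        # Roles without an entry, or entries without a pre/post value, contribute nothing.
        pre_message_str = role_dict[role]["pre_message"] if role in role_dict and "pre_message" in role_dict[role] else ""
        post_message_str = role_dict[role]["post_message"] if role in role_dict and "post_message" in role_dict[role] else ""
        prompt += pre_message_str + message["content"] + post_message_str
    prompt += final_prompt_value
    return prompt

# Llama-2 chat role mapping, as defined in llama_2_chat_pt in this patch.
llama2_role_dict = {
    "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
    "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
    "assistant": {"post_message": "\n"},
}

# Hypothetical conversation, for illustration only.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What does this PR change?"},
]

print(custom_prompt(role_dict=llama2_role_dict, messages=messages))
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#  [/INST]
# [INST] What does this PR change? [/INST]
```

The commented-out `test_customprompt_together_ai` test points at the user-facing entry point, `litellm.register_prompt_template`, which appears to feed a role dictionary like the one above into this same helper, so `prompt_factory` no longer needs a provider-specific function such as the removed `togetherai_llama_2_chat_pt`.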