diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 51b61b0d8..88cc71c5a 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -7,7 +7,7 @@ import time
 from typing import Callable
 from litellm.utils import ModelResponse
 from typing import Optional
-from .huggingface_model_prompt_templates.factory import prompt_factory, custom_prompt
+from .prompt_templates.factory import prompt_factory, custom_prompt
 
 class HuggingfaceError(Exception):
     def __init__(self, status_code, message):
diff --git a/litellm/llms/huggingface_model_prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
similarity index 83%
rename from litellm/llms/huggingface_model_prompt_templates/factory.py
rename to litellm/llms/prompt_templates/factory.py
index b2c5c1425..d62db1d80 100644
--- a/litellm/llms/huggingface_model_prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -13,8 +13,18 @@ def llama_2_chat_pt(messages):
         prompt += message["content"] + "[/INST]"
     return prompt
 
-def llama_2_pt(messages):
-    return " ".join(message["content"] for message in messages)
+# TogetherAI Llama2 prompt template
+def togetherai_llama_2_chat_pt(messages):
+    prompt = "[INST]\n"
+    for message in messages:
+        if message["role"] == "system":
+            prompt += message["content"]
+        elif message["role"] == "assistant":
+            prompt += message["content"]
+        elif message["role"] == "user":
+            prompt += message["content"]
+    prompt += "\n[/INST]\n\n"
+    return prompt
 
 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
 def falcon_instruct_pt(messages):
@@ -77,21 +87,17 @@ def custom_prompt(role_dict: dict, pre_message_sep: str, post_message_sep: str,
 
 def prompt_factory(model: str, messages: list):
     model = model.lower()
-    if "bloom" in model:
-        return default_pt(messages=messages)
-    elif "flan-t5" in model:
-        return default_pt(messages=messages)
-    elif "meta-llama" in model:
+    if "meta-llama/llama-2" in model:
         if "chat" in model:
             return llama_2_chat_pt(messages=messages)
         else:
             return default_pt(messages=messages)
-    elif "falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
+    elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
if "instruct" in model: return falcon_instruct_pt(messages=messages) else: return default_pt(messages=messages) - elif "mpt" in model: + elif "mosaicml/mpt" in model: if "chat" in model: return mpt_chat_pt(messages=messages) else: @@ -101,9 +107,11 @@ def prompt_factory(model: str, messages: list): return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions else: return default_pt(messages=messages) - elif "wizardcoder" in model: + elif "wizardlm/wizardcoder" in model: return wizardcoder_pt(messages=messages) - elif "phind-codellama" in model: + elif "phind/phind-codellama" in model: return phind_codellama_pt(messages=messages) + elif "togethercomputer/llama-2" in model: + return togetherai_llama_2_chat_pt(messages=messages) else: return default_pt(messages=messages) \ No newline at end of file diff --git a/litellm/llms/together_ai.py b/litellm/llms/together_ai.py index 89be0c407..d55616ef1 100644 --- a/litellm/llms/together_ai.py +++ b/litellm/llms/together_ai.py @@ -5,6 +5,7 @@ import requests import time from typing import Callable from litellm.utils import ModelResponse +from .prompt_templates.factory import prompt_factory, custom_prompt class TogetherAIError(Exception): def __init__(self, status_code, message): @@ -34,21 +35,19 @@ def completion( encoding, api_key, logging_obj, + custom_prompt_dict={}, optional_params=None, litellm_params=None, logger_fn=None, ): headers = validate_environment(api_key) model = model - prompt = "" - for message in messages: - if "role" in message: - if message["role"] == "user": - prompt += f"{message['content']}" - else: - prompt += f"{message['content']}" - else: - prompt += f"{message['content']}" + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt(role_dict=model_prompt_details["roles"], pre_message_sep=model_prompt_details["pre_message_sep"], post_message_sep=model_prompt_details["post_message_sep"], messages=messages) + else: + prompt = prompt_factory(model=model, messages=messages) data = { "model": model, "prompt": prompt, diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index a4340a450..378d5cf6d 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -20,7 +20,7 @@ litellm.use_client = True # litellm.set_verbose = True # litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"]) -user_message = "write me a function to print hello world in python" +user_message = "Write a short poem about the sky" messages = [{"content": user_message, "role": "user"}] @@ -390,12 +390,13 @@ def test_completion_replicate_stability(): def test_completion_together_ai(): model_name = "togethercomputer/llama-2-70b-chat" try: - response = completion(model=model_name, messages=messages) + response = completion(model=model_name, messages=messages, max_tokens=256, logger_fn=logger_fn) # Add any assertions here to check the response print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") +test_completion_together_ai() def test_completion_sagemaker(): try: