diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index ec938f33ef..37c96d7733 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -2,7 +2,7 @@ from enum import Enum
 import requests, traceback
 import json
 from jinja2 import Template, exceptions, Environment, meta
-from typing import Optional
+from typing import Optional, Any
 
 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
@@ -159,26 +159,27 @@ def phind_codellama_pt(messages):
             prompt += "### Assistant\n" + message["content"] + "\n\n"
     return prompt
 
-def hf_chat_template(model: str, messages: list):
+def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
     ## get the tokenizer config from huggingface
-    def _get_tokenizer_config(hf_model_name):
-        url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
-        # Make a GET request to fetch the JSON data
-        response = requests.get(url)
-        if response.status_code == 200:
-            # Parse the JSON data
-            tokenizer_config = json.loads(response.content)
-            return {"status": "success", "tokenizer": tokenizer_config}
-        else:
-            return {"status": "failure"}
-    tokenizer_config = _get_tokenizer_config(model)
-    if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
-        raise Exception("No chat template found")
-    ## read the bos token, eos token and chat template from the json
-    tokenizer_config = tokenizer_config["tokenizer"]
-    bos_token = tokenizer_config["bos_token"]
-    eos_token = tokenizer_config["eos_token"]
-    chat_template = tokenizer_config["chat_template"]
+    if chat_template is None:
+        def _get_tokenizer_config(hf_model_name):
+            url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
+            # Make a GET request to fetch the JSON data
+            response = requests.get(url)
+            if response.status_code == 200:
+                # Parse the JSON data
+                tokenizer_config = json.loads(response.content)
+                return {"status": "success", "tokenizer": tokenizer_config}
+            else:
+                return {"status": "failure"}
+        tokenizer_config = _get_tokenizer_config(model)
+        if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
+            raise Exception("No chat template found")
+        ## read the bos token, eos token and chat template from the json
+        tokenizer_config = tokenizer_config["tokenizer"]
+        bos_token = tokenizer_config["bos_token"]
+        eos_token = tokenizer_config["eos_token"]
+        chat_template = tokenizer_config["chat_template"]
 
     def raise_exception(message):
         raise Exception(f"Error message - {message}")
@@ -262,6 +263,35 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
         prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
     return prompt
 
+### TOGETHER AI
+
+def get_model_info(token, model):
+    headers = {
+        'Authorization': f'Bearer {token}'
+    }
+    response = requests.get('https://api.together.xyz/models/info', headers=headers)
+    if response.status_code == 200:
+        model_info = response.json()
+        for m in model_info:
+            if m["name"].lower().strip() == model.strip():
+                return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
+        return None, None
+    else:
+        return None, None
+
+def format_prompt_togetherai(messages, prompt_format, chat_template):
+    if chat_template is not None:
+        prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template)
+    elif prompt_format is not None:
+        # only split prompt_format into human/assistant pieces when the model has one
+        human_prompt, assistant_prompt = prompt_format.split('{prompt}')
+        prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt)
+    else:
+        prompt = default_pt(messages)
+    return prompt
+
+###
+
 def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
     class AnthropicConstants(Enum):
         HUMAN_PROMPT = "\n\nHuman: "
@@ -328,7 +358,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
     prompt += final_prompt_value
     return prompt
 
-def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
+def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
     original_model_name = model
     model = model.lower()
     if custom_llm_provider == "ollama":
@@ -338,7 +368,9 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
             return claude_2_1_pt(messages=messages)
         else:
             return anthropic_pt(messages=messages)
-    
+    elif custom_llm_provider == "together_ai":
+        prompt_format, chat_template = get_model_info(token=api_key, model=model)
+        return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
diff --git a/litellm/llms/together_ai.py b/litellm/llms/together_ai.py
index 8e4970a7b6..210ed497ee 100644
--- a/litellm/llms/together_ai.py
+++ b/litellm/llms/together_ai.py
@@ -115,7 +115,7 @@ def completion(
             messages=messages,
         )
     else:
-        prompt = prompt_factory(model=model, messages=messages)
+        prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list
 
     data = {
         "model": model,
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 7727cc7107..774131a5c0 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1001,8 +1001,13 @@ def test_replicate_custom_prompt_dict():
 
 ######## Test TogetherAI ########
 def test_completion_together_ai():
-    model_name = "together_ai/togethercomputer/llama-2-70b-chat"
+    model_name = "together_ai/togethercomputer/CodeLlama-13b-Instruct"
     try:
+        messages = [
+            {"role": "user", "content": "Who are you"},
+            {"role": "assistant", "content": "I am your helpful assistant."},
+            {"role": "user", "content": "Tell me a joke"},
+        ]
         response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn)
         # Add any assertions here to check the response
         print(response)
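
For reference, a minimal sketch of how the new Together AI prompt path can be exercised directly. It assumes a valid Together AI key exported as TOGETHERAI_API_KEY and that the example model (the one used in the updated test) is listed by https://api.together.xyz/models/info; the import path follows the factory module touched above.

import os
from litellm.llms.prompt_templates.factory import prompt_factory

messages = [
    {"role": "user", "content": "Who are you"},
    {"role": "assistant", "content": "I am your helpful assistant."},
    {"role": "user", "content": "Tell me a joke"},
]

# With custom_llm_provider="together_ai", prompt_factory calls get_model_info()
# against the Together AI models/info endpoint using api_key, then formats the
# messages via the model's chat_template (hf_chat_template), its prompt_format
# (custom_prompt), or default_pt() when neither is available.
prompt = prompt_factory(
    model="togethercomputer/CodeLlama-13b-Instruct",  # example model name from the test
    messages=messages,
    custom_llm_provider="together_ai",
    api_key=os.environ.get("TOGETHERAI_API_KEY"),
)
print(prompt)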