fix(together_ai.py): improve together ai custom prompt templating

Krrish Dholakia 2023-12-06 19:34:42 -08:00
parent 04eecaa493
commit ac7d0a1632
3 changed files with 61 additions and 24 deletions

litellm/llms/prompt_templates/factory.py

@@ -2,7 +2,7 @@ from enum import Enum
 import requests, traceback
 import json
 from jinja2 import Template, exceptions, Environment, meta
-from typing import Optional
+from typing import Optional, Any

 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
@@ -159,26 +159,27 @@ def phind_codellama_pt(messages):
             prompt += "### Assistant\n" + message["content"] + "\n\n"
     return prompt

-def hf_chat_template(model: str, messages: list):
+def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
     ## get the tokenizer config from huggingface
-    def _get_tokenizer_config(hf_model_name):
-        url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
-        # Make a GET request to fetch the JSON data
-        response = requests.get(url)
-        if response.status_code == 200:
-            # Parse the JSON data
-            tokenizer_config = json.loads(response.content)
-            return {"status": "success", "tokenizer": tokenizer_config}
-        else:
-            return {"status": "failure"}
-    tokenizer_config = _get_tokenizer_config(model)
-    if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
-        raise Exception("No chat template found")
-    ## read the bos token, eos token and chat template from the json
-    tokenizer_config = tokenizer_config["tokenizer"]
-    bos_token = tokenizer_config["bos_token"]
-    eos_token = tokenizer_config["eos_token"]
-    chat_template = tokenizer_config["chat_template"]
+    if chat_template is None:
+        def _get_tokenizer_config(hf_model_name):
+            url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
+            # Make a GET request to fetch the JSON data
+            response = requests.get(url)
+            if response.status_code == 200:
+                # Parse the JSON data
+                tokenizer_config = json.loads(response.content)
+                return {"status": "success", "tokenizer": tokenizer_config}
+            else:
+                return {"status": "failure"}
+        tokenizer_config = _get_tokenizer_config(model)
+        if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
+            raise Exception("No chat template found")
+        ## read the bos token, eos token and chat template from the json
+        tokenizer_config = tokenizer_config["tokenizer"]
+        bos_token = tokenizer_config["bos_token"]
+        eos_token = tokenizer_config["eos_token"]
+        chat_template = tokenizer_config["chat_template"]

     def raise_exception(message):
         raise Exception(f"Error message - {message}")
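Note: with a caller-supplied chat_template, the tokenizer_config.json fetch is skipped and the template is rendered directly. A standalone sketch of what that render step amounts to; the template string and special tokens below are illustrative, not taken from any real model:

from jinja2 import Template

# Illustrative Jinja2 chat template, in the same shape HuggingFace tokenizer configs use.
chat_template = (
    "{% for m in messages %}"
    "<|{{ m['role'] }}|>\n{{ m['content'] }}{{ eos_token }}\n"
    "{% endfor %}"
)
messages = [{"role": "user", "content": "Tell me a joke"}]
# Render the template over the message list plus special tokens, as hf_chat_template does.
print(Template(chat_template).render(messages=messages, bos_token="<s>", eos_token="</s>"))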
@@ -262,6 +263,35 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
     prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: \" turn
     return prompt
+### TOGETHER AI
+def get_model_info(token, model):
+    headers = {
+        'Authorization': f'Bearer {token}'
+    }
+    response = requests.get('https://api.together.xyz/models/info', headers=headers)
+    if response.status_code == 200:
+        model_info = response.json()
+        for m in model_info:
+            if m["name"].lower().strip() == model.strip():
+                return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
+        return None, None
+    else:
+        return None, None
+
+def format_prompt_togetherai(messages, prompt_format, chat_template):
+    if chat_template is not None:
+        prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template)
+    elif prompt_format is not None:
+        # only split when a prompt_format actually exists, otherwise this would raise on None
+        human_prompt, assistant_prompt = prompt_format.split('{prompt}')
+        prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt)
+    else:
+        prompt = default_pt(messages)
+    return prompt
+###

 def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
     class AnthropicConstants(Enum):
         HUMAN_PROMPT = "\n\nHuman: "
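The two helpers compose as lookup-then-format. A rough usage sketch, assuming this file is litellm's prompt_templates/factory.py (hence the import path below) and that TOGETHERAI_API_KEY holds a valid Together AI key; the name comparison expects a lowercased model, which prompt_factory guarantees:

import os
from litellm.llms.prompt_templates.factory import get_model_info, format_prompt_togetherai

model = "togethercomputer/codellama-13b-instruct"  # illustrative, already lowercased
# Query Together's /models/info for this model's prompt format and chat template.
prompt_format, chat_template = get_model_info(token=os.environ["TOGETHERAI_API_KEY"], model=model)
prompt = format_prompt_togetherai(
    messages=[{"role": "user", "content": "Tell me a joke"}],
    prompt_format=prompt_format,
    chat_template=chat_template,
)
print(prompt)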
@@ -328,7 +358,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
     prompt += final_prompt_value
     return prompt

-def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
+def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
     original_model_name = model
     model = model.lower()
     if custom_llm_provider == "ollama":
@@ -338,7 +368,9 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
             return claude_2_1_pt(messages=messages)
         else:
             return anthropic_pt(messages=messages)
+    elif custom_llm_provider == "together_ai":
+        prompt_format, chat_template = get_model_info(token=api_key, model=model)
+        return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
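Callers go through prompt_factory rather than the helpers directly; a sketch of the new keyword arguments, under the same import-path assumption as above:

import os
from litellm.llms.prompt_templates.factory import prompt_factory

prompt = prompt_factory(
    model="togethercomputer/CodeLlama-13b-Instruct",  # lowercased internally before the lookup
    messages=[{"role": "user", "content": "Tell me a joke"}],
    custom_llm_provider="together_ai",
    api_key=os.environ["TOGETHERAI_API_KEY"],  # forwarded to get_model_info
)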

litellm/llms/together_ai.py

@@ -115,7 +115,7 @@ def completion(
             messages=messages,
         )
     else:
-        prompt = prompt_factory(model=model, messages=messages)
+        prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list
     data = {
         "model": model,
litellm/tests/test_completion.py

@@ -1001,8 +1001,13 @@ def test_replicate_custom_prompt_dict():

 ######## Test TogetherAI ########
 def test_completion_together_ai():
-    model_name = "together_ai/togethercomputer/llama-2-70b-chat"
+    model_name = "together_ai/togethercomputer/CodeLlama-13b-Instruct"
     try:
+        messages = [
+            {"role": "user", "content": "Who are you"},
+            {"role": "assistant", "content": "I am your helpful assistant."},
+            {"role": "user", "content": "Tell me a joke"},
+        ]
         response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn)
         # Add any assertions here to check the response
         print(response)