fix(together_ai.py): improve together ai custom prompt templating

Krrish Dholakia 2023-12-06 19:34:42 -08:00
parent 04eecaa493
commit ac7d0a1632
3 changed files with 61 additions and 24 deletions


```diff
@@ -2,7 +2,7 @@ from enum import Enum
 import requests, traceback
 import json
 from jinja2 import Template, exceptions, Environment, meta
-from typing import Optional
+from typing import Optional, Any
 
 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
@@ -159,26 +159,27 @@ def phind_codellama_pt(messages):
             prompt += "### Assistant\n" + message["content"] + "\n\n"
     return prompt
 
-def hf_chat_template(model: str, messages: list):
+def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
     ## get the tokenizer config from huggingface
-    def _get_tokenizer_config(hf_model_name):
-        url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
-        # Make a GET request to fetch the JSON data
-        response = requests.get(url)
-        if response.status_code == 200:
-            # Parse the JSON data
-            tokenizer_config = json.loads(response.content)
-            return {"status": "success", "tokenizer": tokenizer_config}
-        else:
-            return {"status": "failure"}
-    tokenizer_config = _get_tokenizer_config(model)
-    if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
-        raise Exception("No chat template found")
-    ## read the bos token, eos token and chat template from the json
-    tokenizer_config = tokenizer_config["tokenizer"]
-    bos_token = tokenizer_config["bos_token"]
-    eos_token = tokenizer_config["eos_token"]
-    chat_template = tokenizer_config["chat_template"]
+    if chat_template is None:
+        def _get_tokenizer_config(hf_model_name):
+            url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
+            # Make a GET request to fetch the JSON data
+            response = requests.get(url)
+            if response.status_code == 200:
+                # Parse the JSON data
+                tokenizer_config = json.loads(response.content)
+                return {"status": "success", "tokenizer": tokenizer_config}
+            else:
+                return {"status": "failure"}
+        tokenizer_config = _get_tokenizer_config(model)
+        if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
+            raise Exception("No chat template found")
+        ## read the bos token, eos token and chat template from the json
+        tokenizer_config = tokenizer_config["tokenizer"]
+        bos_token = tokenizer_config["bos_token"]
+        eos_token = tokenizer_config["eos_token"]
+        chat_template = tokenizer_config["chat_template"]
 
     def raise_exception(message):
         raise Exception(f"Error message - {message}")
```
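The rest of hf_chat_template (not shown in this hunk) renders the fetched or caller-supplied chat_template with Jinja2. Below is a minimal, self-contained sketch of that rendering step; the template string and the bos_token/eos_token values are invented for illustration rather than read from a real tokenizer_config.json.

```python
from jinja2 import Environment

# Hypothetical Llama-2-style chat template; real ones come from
# tokenizer_config.json or, after this change, from the caller.
chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ bos_token }}[INST] {{ message['content'] }} [/INST]"
    "{% elif message['role'] == 'assistant' %}"
    " {{ message['content'] }}{{ eos_token }}"
    "{% endif %}"
    "{% endfor %}"
)

messages = [
    {"role": "user", "content": "Who are you"},
    {"role": "assistant", "content": "I am your helpful assistant."},
    {"role": "user", "content": "Tell me a joke"},
]

env = Environment()
template = env.from_string(chat_template)
print(template.render(messages=messages, bos_token="<s>", eos_token="</s>"))
# <s>[INST] Who are you [/INST] I am your helpful assistant.</s><s>[INST] Tell me a joke [/INST]
```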
```diff
@@ -262,6 +263,35 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
     prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
     return prompt
 
+### TOGETHER AI
+
+def get_model_info(token, model):
+    headers = {
+        'Authorization': f'Bearer {token}'
+    }
+    response = requests.get('https://api.together.xyz/models/info', headers=headers)
+    if response.status_code == 200:
+        model_info = response.json()
+        for m in model_info:
+            if m["name"].lower().strip() == model.strip():
+                return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
+        return None, None
+    else:
+        return None, None
+
+def format_prompt_togetherai(messages, prompt_format, chat_template):
+    if chat_template is not None:
+        prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template)
+    elif prompt_format is not None:
+        # only split on the {prompt} placeholder once we know a prompt_format exists
+        human_prompt, assistant_prompt = prompt_format.split('{prompt}')
+        prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt)
+    else:
+        prompt = default_pt(messages)
+    return prompt
+
+###
+
 def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
     class AnthropicConstants(Enum):
         HUMAN_PROMPT = "\n\nHuman: "
```
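A rough sketch of how the two helpers combine. The /models/info entry below is hypothetical and limited to the fields get_model_info actually reads (name, config.prompt_format, config.chat_template); when the API returns a chat_template instead, format_prompt_togetherai hands it to hf_chat_template above.

```python
# Hypothetical /models/info entry, limited to the fields get_model_info() reads;
# the actual Together AI response may carry more.
model_entry = {
    "name": "togethercomputer/CodeLlama-13b-Instruct",
    "config": {"prompt_format": "[INST] {prompt} [/INST]", "chat_template": None},
}

prompt_format = model_entry["config"].get("prompt_format")
messages = [{"role": "user", "content": "Tell me a joke"}]

# Roughly what format_prompt_togetherai() does when only prompt_format is set:
# split on the {prompt} placeholder and wrap the message contents with the two halves.
human_prompt, assistant_prompt = prompt_format.split("{prompt}")
prompt = human_prompt + " ".join(m["content"] for m in messages) + assistant_prompt
print(prompt)  # [INST] Tell me a joke [/INST]
```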
```diff
@@ -328,7 +358,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
     prompt += final_prompt_value
     return prompt
 
-def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
+def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
     original_model_name = model
     model = model.lower()
     if custom_llm_provider == "ollama":
@@ -338,7 +368,9 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
             return claude_2_1_pt(messages=messages)
         else:
             return anthropic_pt(messages=messages)
+    elif custom_llm_provider == "together_ai":
+        prompt_format, chat_template = get_model_info(token=api_key, model=model)
+        return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
```
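A hedged usage sketch of the new routing; the import path and the TOGETHERAI_API_KEY environment variable name are assumptions for illustration, not taken from this diff.

```python
import os

# Assumed location of the prompt-template helpers shown above.
from litellm.llms.prompt_templates.factory import prompt_factory

prompt = prompt_factory(
    model="togethercomputer/codellama-13b-instruct",
    messages=[{"role": "user", "content": "Tell me a joke"}],
    custom_llm_provider="together_ai",
    api_key=os.environ.get("TOGETHERAI_API_KEY"),  # assumed env var; any valid Together AI key works
)
print(prompt)
```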


```diff
@@ -115,7 +115,7 @@ def completion(
             messages=messages,
         )
     else:
-        prompt = prompt_factory(model=model, messages=messages)
+        prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list
 
     data = {
         "model": model,
```


```diff
@@ -1001,8 +1001,13 @@ def test_replicate_custom_prompt_dict():
 ######## Test TogetherAI ########
 def test_completion_together_ai():
-    model_name = "together_ai/togethercomputer/llama-2-70b-chat"
+    model_name = "together_ai/togethercomputer/CodeLlama-13b-Instruct"
     try:
+        messages = [
+            {"role": "user", "content": "Who are you"},
+            {"role": "assistant", "content": "I am your helpful assistant."},
+            {"role": "user", "content": "Tell me a joke"},
+        ]
         response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn)
         # Add any assertions here to check the response
         print(response)
```