Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

fix(together_ai.py): improve together ai custom prompt templating

parent 04eecaa493
commit ac7d0a1632

3 changed files with 61 additions and 24 deletions
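Summary of the change: prompt_factory gains an optional api_key parameter and a together_ai branch; for TogetherAI models it calls the new get_model_info helper to fetch the model's prompt_format and chat_template from https://api.together.xyz/models/info, then builds the prompt with the new format_prompt_togetherai helper. hf_chat_template now accepts an explicit chat_template, so the HuggingFace tokenizer_config fetch is skipped when a template is already known. together_ai.py passes its API key through to prompt_factory, and the TogetherAI test switches to CodeLlama-13b-Instruct with a multi-turn message list.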
@@ -2,7 +2,7 @@ from enum import Enum
 import requests, traceback
 import json
 from jinja2 import Template, exceptions, Environment, meta
-from typing import Optional
+from typing import Optional, Any
 
 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
@@ -159,26 +159,27 @@ def phind_codellama_pt(messages):
         prompt += "### Assistant\n" + message["content"] + "\n\n"
     return prompt
 
-def hf_chat_template(model: str, messages: list):
+def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
     ## get the tokenizer config from huggingface
-    def _get_tokenizer_config(hf_model_name):
-        url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
-        # Make a GET request to fetch the JSON data
-        response = requests.get(url)
-        if response.status_code == 200:
-            # Parse the JSON data
-            tokenizer_config = json.loads(response.content)
-            return {"status": "success", "tokenizer": tokenizer_config}
-        else:
-            return {"status": "failure"}
-    tokenizer_config = _get_tokenizer_config(model)
-    if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
-        raise Exception("No chat template found")
-    ## read the bos token, eos token and chat template from the json
-    tokenizer_config = tokenizer_config["tokenizer"]
-    bos_token = tokenizer_config["bos_token"]
-    eos_token = tokenizer_config["eos_token"]
-    chat_template = tokenizer_config["chat_template"]
+    if chat_template is None:
+        def _get_tokenizer_config(hf_model_name):
+            url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
+            # Make a GET request to fetch the JSON data
+            response = requests.get(url)
+            if response.status_code == 200:
+                # Parse the JSON data
+                tokenizer_config = json.loads(response.content)
+                return {"status": "success", "tokenizer": tokenizer_config}
+            else:
+                return {"status": "failure"}
+        tokenizer_config = _get_tokenizer_config(model)
+        if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
+            raise Exception("No chat template found")
+        ## read the bos token, eos token and chat template from the json
+        tokenizer_config = tokenizer_config["tokenizer"]
+        bos_token = tokenizer_config["bos_token"]
+        eos_token = tokenizer_config["eos_token"]
+        chat_template = tokenizer_config["chat_template"]
 
     def raise_exception(message):
         raise Exception(f"Error message - {message}")
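The practical effect of the new chat_template argument is that a caller who already has a Jinja2 chat template (as TogetherAI's model info can return) no longer triggers the tokenizer_config.json download. A minimal, self-contained sketch of that rendering step follows; it is not the litellm implementation itself, and the template string and special tokens are made up for illustration:

from jinja2 import Environment, BaseLoader

# Hypothetical chat template, in the style HuggingFace / TogetherAI templates use.
EXAMPLE_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]"
    "{% else %}{{ message['content'] }}{{ eos_token }}"
    "{% endif %}"
    "{% endfor %}"
)

def render_chat_template(messages, chat_template, bos_token="<s>", eos_token="</s>"):
    # When a chat_template is supplied directly, no tokenizer_config.json lookup is needed.
    env = Environment(loader=BaseLoader())
    template = env.from_string(chat_template)
    return template.render(messages=messages, bos_token=bos_token, eos_token=eos_token)

if __name__ == "__main__":
    messages = [
        {"role": "user", "content": "Who are you"},
        {"role": "assistant", "content": "I am your helpful assistant."},
        {"role": "user", "content": "Tell me a joke"},
    ]
    print(render_chat_template(messages, EXAMPLE_CHAT_TEMPLATE))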
@@ -262,6 +263,35 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
     prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
     return prompt
 
+### TOGETHER AI
+
+def get_model_info(token, model):
+    headers = {
+        'Authorization': f'Bearer {token}'
+    }
+    response = requests.get('https://api.together.xyz/models/info', headers=headers)
+    if response.status_code == 200:
+        model_info = response.json()
+        for m in model_info:
+            if m["name"].lower().strip() == model.strip():
+                return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
+        return None, None
+    else:
+        return None, None
+
+def format_prompt_togetherai(messages, prompt_format, chat_template):
+    human_prompt, assistant_prompt = prompt_format.split('{prompt}')
+
+    if chat_template is not None:
+        prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template)
+    elif prompt_format is not None:
+        prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt)
+    else:
+        prompt = default_pt(messages)
+    return prompt
+
+###
+
 def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
     class AnthropicConstants(Enum):
         HUMAN_PROMPT = "\n\nHuman: "
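For intuition: a TogetherAI prompt_format is a template containing a {prompt} placeholder, and format_prompt_togetherai splits it on that placeholder to get the text that should go before and after the user content. The format string below is illustrative only; real values come back from the /models/info response:

# Illustrative prompt_format; real values are returned by https://api.together.xyz/models/info
example_prompt_format = "[INST] {prompt} [/INST]"

human_prompt, assistant_prompt = example_prompt_format.split('{prompt}')
print(repr(human_prompt))      # '[INST] '
print(repr(assistant_prompt))  # ' [/INST]'

# A single user turn would then be wrapped roughly as:
user_message = "Tell me a joke"
print(human_prompt + user_message + assistant_prompt)  # [INST] Tell me a joke [/INST]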
@@ -328,7 +358,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
         prompt += final_prompt_value
     return prompt
 
-def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
+def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
     original_model_name = model
     model = model.lower()
     if custom_llm_provider == "ollama":
@@ -338,7 +368,9 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
             return claude_2_1_pt(messages=messages)
         else:
             return anthropic_pt(messages=messages)
+    elif custom_llm_provider == "together_ai":
+        prompt_format, chat_template = get_model_info(token=api_key, model=model)
+        return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
@@ -115,7 +115,7 @@ def completion(
             messages=messages,
         )
     else:
-        prompt = prompt_factory(model=model, messages=messages)
+        prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list
 
     data = {
         "model": model,
@@ -1001,8 +1001,13 @@ def test_replicate_custom_prompt_dict():
 
 ######## Test TogetherAI ########
 def test_completion_together_ai():
-    model_name = "together_ai/togethercomputer/llama-2-70b-chat"
+    model_name = "together_ai/togethercomputer/CodeLlama-13b-Instruct"
     try:
+        messages = [
+            {"role": "user", "content": "Who are you"},
+            {"role": "assistant", "content": "I am your helpful assistant."},
+            {"role": "user", "content": "Tell me a joke"},
+        ]
         response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn)
         # Add any assertions here to check the response
         print(response)
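To exercise the updated path outside the test suite, a minimal call looks like the sketch below. It assumes litellm is installed and that a valid TogetherAI key is available; the environment variable name TOGETHERAI_API_KEY is an assumption to verify against your litellm version:

import os
from litellm import completion

# Assumed env var name for the TogetherAI key (also used when querying /models/info).
os.environ["TOGETHERAI_API_KEY"] = "your-api-key-here"

messages = [
    {"role": "user", "content": "Who are you"},
    {"role": "assistant", "content": "I am your helpful assistant."},
    {"role": "user", "content": "Tell me a joke"},
]

# With this commit, prompt_factory queries TogetherAI's /models/info for the
# model's prompt_format / chat_template before building the final prompt.
response = completion(
    model="together_ai/togethercomputer/CodeLlama-13b-Instruct",
    messages=messages,
    max_tokens=256,
)
print(response)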