Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(together_ai.py): improve together ai custom prompt templating
parent 04eecaa493
commit ac7d0a1632
3 changed files with 61 additions and 24 deletions
litellm/llms/prompt_templates/factory.py

@@ -2,7 +2,7 @@ from enum import Enum
 import requests, traceback
 import json
 from jinja2 import Template, exceptions, Environment, meta
-from typing import Optional
+from typing import Optional, Any

 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
@@ -159,8 +159,9 @@ def phind_codellama_pt(messages):
         prompt += "### Assistant\n" + message["content"] + "\n\n"
     return prompt

-def hf_chat_template(model: str, messages: list):
+def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
     ## get the tokenizer config from huggingface
+    if chat_template is None:
     def _get_tokenizer_config(hf_model_name):
         url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
         # Make a GET request to fetch the JSON data
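For context, the new chat_template parameter lets a caller hand hf_chat_template a template string directly instead of fetching one from tokenizer_config.json over HTTP. A minimal sketch of what rendering such a template looks like, using a toy template string (an assumption for illustration, not litellm's exact internals):

from jinja2 import Template

# Toy template in the tokenizer_config.json "chat_template" style (assumed
# shape; real templates also reference bos_token/eos_token and role rules).
chat_template = (
    "{% for message in messages %}"
    "{{ '<|' + message['role'] + '|>' + message['content'] }}\n"
    "{% endfor %}"
)

messages = [
    {"role": "user", "content": "Who are you"},
    {"role": "assistant", "content": "I am your helpful assistant."},
]

# With a template passed in, the HTTP fetch above can be skipped entirely.
print(Template(chat_template).render(messages=messages))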
@@ -262,6 +263,35 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
     prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
     return prompt

+### TOGETHER AI
+
+def get_model_info(token, model):
+    headers = {
+        'Authorization': f'Bearer {token}'
+    }
+    response = requests.get('https://api.together.xyz/models/info', headers=headers)
+    if response.status_code == 200:
+        model_info = response.json()
+        for m in model_info:
+            if m["name"].lower().strip() == model.strip():
+                return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
+        return None, None
+    else:
+        return None, None
+
+def format_prompt_togetherai(messages, prompt_format, chat_template):
+    human_prompt, assistant_prompt = prompt_format.split('{prompt}')
+
+    if chat_template is not None:
+        prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template)
+    elif prompt_format is not None:
+        prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt)
+    else:
+        prompt = default_pt(messages)
+    return prompt
+
+###
+
 def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
     class AnthropicConstants(Enum):
         HUMAN_PROMPT = "\n\nHuman: "
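The prompt_format contract assumed here is a string containing a {prompt} placeholder, e.g. a Llama-style "[INST] {prompt} [/INST]" (an illustrative value; actual formats vary per model and come from the /models/info response). A worked sketch of the split, plus a guard a defensive caller might want, since format_prompt_togetherai splits before checking prompt_format is not None:

# Illustrative value, not fetched live from the Together AI API.
prompt_format = "[INST] {prompt} [/INST]"
human_prompt, assistant_prompt = prompt_format.split('{prompt}')
assert human_prompt == "[INST] " and assistant_prompt == " [/INST]"

# Hypothetical guard (not in this commit): avoids an AttributeError when a
# model's info has neither a prompt_format nor a chat_template.
def safe_split(prompt_format):
    if prompt_format is None:
        return "", ""
    return prompt_format.split('{prompt}')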
@ -328,7 +358,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
|
|||
prompt += final_prompt_value
|
||||
return prompt
|
||||
|
||||
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
|
||||
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
|
||||
original_model_name = model
|
||||
model = model.lower()
|
||||
if custom_llm_provider == "ollama":
|
||||
|
@@ -338,7 +368,9 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
             return claude_2_1_pt(messages=messages)
         else:
             return anthropic_pt(messages=messages)
-
+    elif custom_llm_provider == "together_ai":
+        prompt_format, chat_template = get_model_info(token=api_key, model=model)
+        return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
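End to end, the new branch means a Together AI prompt is resolved from the provider's own model metadata at prompt-build time rather than from hardcoded per-model templates. A hypothetical usage sketch (the import path and env var name are assumed from litellm's layout, not shown in this diff):

import os
from litellm.llms.prompt_templates.factory import prompt_factory  # assumed path

prompt = prompt_factory(
    model="togethercomputer/CodeLlama-13b-Instruct",
    messages=[{"role": "user", "content": "Tell me a joke"}],
    custom_llm_provider="together_ai",
    api_key=os.getenv("TOGETHERAI_API_KEY"),  # needed for the /models/info lookup
)
print(prompt)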
litellm/llms/together_ai.py

@@ -115,7 +115,7 @@ def completion(
                 messages=messages,
             )
         else:
-            prompt = prompt_factory(model=model, messages=messages)
+            prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list

         data = {
             "model": model,
litellm/tests/test_completion.py

@@ -1001,8 +1001,13 @@ def test_replicate_custom_prompt_dict():

 ######## Test TogetherAI ########
 def test_completion_together_ai():
-    model_name = "together_ai/togethercomputer/llama-2-70b-chat"
+    model_name = "together_ai/togethercomputer/CodeLlama-13b-Instruct"
     try:
+        messages = [
+            {"role": "user", "content": "Who are you"},
+            {"role": "assistant", "content": "I am your helpful assistant."},
+            {"role": "user", "content": "Tell me a joke"},
+        ]
         response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn)
         # Add any assertions here to check the response
         print(response)
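To exercise the updated test locally, something like the following should work, assuming a valid key is exported in the environment (the env var name follows litellm's conventions and is not shown in this diff):

# Run just the updated test (assumes TOGETHERAI_API_KEY is set).
import pytest
pytest.main(["-k", "test_completion_together_ai", "litellm/tests/test_completion.py"])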