fix(bedrock.py): fix amazon titan prompt formatting

This commit is contained in:
Krrish Dholakia 2024-02-13 22:02:15 -08:00
parent b81f8ec8ca
commit cb5a13ed49
2 changed files with 45 additions and 6 deletions

View file

@ -477,8 +477,8 @@ def init_bedrock_client(
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
# handle anthropic prompts using anthropic constants
if provider == "anthropic":
# handle anthropic prompts and amazon titan prompts
if provider == "anthropic" or provider == "amazon":
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
@ -490,7 +490,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
model=model, messages=messages, custom_llm_provider="bedrock"
)
else:
prompt = ""
@ -623,6 +623,7 @@ def completion(
"textGenerationConfig": inference_params,
}
)
else:
data = json.dumps({})

View file

@ -90,9 +90,11 @@ def ollama_pt(
return {"prompt": prompt, "images": images}
else:
prompt = "".join(
m["content"]
if isinstance(m["content"], str) is str
else "".join(m["content"])
(
m["content"]
if isinstance(m["content"], str) is str
else "".join(m["content"])
)
for m in messages
)
return prompt
@ -422,6 +424,34 @@ def anthropic_pt(
return prompt
def amazon_titan_pt(
messages: list,
): # format - https://github.com/BerriAI/litellm/issues/1896
"""
Amazon Titan uses 'User:' and 'Bot: in it's prompt template
"""
class AmazonTitanConstants(Enum):
HUMAN_PROMPT = "\n\nUser: " # Assuming this is similar to Anthropic prompt formatting, since amazon titan's prompt formatting is currently undocumented
AI_PROMPT = "\n\nBot: "
prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "user":
prompt += f"{AmazonTitanConstants.HUMAN_PROMPT.value}{message['content']}"
elif message["role"] == "system":
prompt += f"{AmazonTitanConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
else:
prompt += f"{AmazonTitanConstants.AI_PROMPT.value}{message['content']}"
if (
idx == 0 and message["role"] == "assistant"
): # ensure the prompt always starts with `\n\nHuman: `
prompt = f"{AmazonTitanConstants.HUMAN_PROMPT.value}" + prompt
if messages[-1]["role"] != "assistant":
prompt += f"{AmazonTitanConstants.AI_PROMPT.value}"
return prompt
def _load_image_from_url(image_url):
try:
from PIL import Image
@ -636,6 +666,14 @@ def prompt_factory(
return gemini_text_image_pt(messages=messages)
elif custom_llm_provider == "mistral":
return mistral_api_pt(messages=messages)
elif custom_llm_provider == "bedrock":
if "amazon.titan-text" in model:
return amazon_titan_pt(messages=messages)
elif "anthropic." in model:
if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
return claude_2_1_pt(messages=messages)
else:
return anthropic_pt(messages=messages)
try:
if "meta-llama/llama-2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages)