forked from phoenix/litellm-mirror
fix(bedrock.py): fix amazon titan prompt formatting
This commit is contained in:
parent
b81f8ec8ca
commit
cb5a13ed49
2 changed files with 45 additions and 6 deletions
|
@ -477,8 +477,8 @@ def init_bedrock_client(
|
||||||
|
|
||||||
|
|
||||||
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
|
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
|
||||||
# handle anthropic prompts using anthropic constants
|
# handle anthropic prompts and amazon titan prompts
|
||||||
if provider == "anthropic":
|
if provider == "anthropic" or provider == "amazon":
|
||||||
if model in custom_prompt_dict:
|
if model in custom_prompt_dict:
|
||||||
# check if the model has a registered custom prompt
|
# check if the model has a registered custom prompt
|
||||||
model_prompt_details = custom_prompt_dict[model]
|
model_prompt_details = custom_prompt_dict[model]
|
||||||
|
@ -490,7 +490,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt = prompt_factory(
|
prompt = prompt_factory(
|
||||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt = ""
|
prompt = ""
|
||||||
|
@ -623,6 +623,7 @@ def completion(
|
||||||
"textGenerationConfig": inference_params,
|
"textGenerationConfig": inference_params,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
data = json.dumps({})
|
data = json.dumps({})
|
||||||
|
|
||||||
|
|
|
@ -90,9 +90,11 @@ def ollama_pt(
|
||||||
return {"prompt": prompt, "images": images}
|
return {"prompt": prompt, "images": images}
|
||||||
else:
|
else:
|
||||||
prompt = "".join(
|
prompt = "".join(
|
||||||
|
(
|
||||||
m["content"]
|
m["content"]
|
||||||
if isinstance(m["content"], str) is str
|
if isinstance(m["content"], str) is str
|
||||||
else "".join(m["content"])
|
else "".join(m["content"])
|
||||||
|
)
|
||||||
for m in messages
|
for m in messages
|
||||||
)
|
)
|
||||||
return prompt
|
return prompt
|
||||||
|
@ -422,6 +424,34 @@ def anthropic_pt(
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def amazon_titan_pt(
|
||||||
|
messages: list,
|
||||||
|
): # format - https://github.com/BerriAI/litellm/issues/1896
|
||||||
|
"""
|
||||||
|
Amazon Titan uses 'User:' and 'Bot: in it's prompt template
|
||||||
|
"""
|
||||||
|
|
||||||
|
class AmazonTitanConstants(Enum):
|
||||||
|
HUMAN_PROMPT = "\n\nUser: " # Assuming this is similar to Anthropic prompt formatting, since amazon titan's prompt formatting is currently undocumented
|
||||||
|
AI_PROMPT = "\n\nBot: "
|
||||||
|
|
||||||
|
prompt = ""
|
||||||
|
for idx, message in enumerate(messages):
|
||||||
|
if message["role"] == "user":
|
||||||
|
prompt += f"{AmazonTitanConstants.HUMAN_PROMPT.value}{message['content']}"
|
||||||
|
elif message["role"] == "system":
|
||||||
|
prompt += f"{AmazonTitanConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
|
||||||
|
else:
|
||||||
|
prompt += f"{AmazonTitanConstants.AI_PROMPT.value}{message['content']}"
|
||||||
|
if (
|
||||||
|
idx == 0 and message["role"] == "assistant"
|
||||||
|
): # ensure the prompt always starts with `\n\nHuman: `
|
||||||
|
prompt = f"{AmazonTitanConstants.HUMAN_PROMPT.value}" + prompt
|
||||||
|
if messages[-1]["role"] != "assistant":
|
||||||
|
prompt += f"{AmazonTitanConstants.AI_PROMPT.value}"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def _load_image_from_url(image_url):
|
def _load_image_from_url(image_url):
|
||||||
try:
|
try:
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -636,6 +666,14 @@ def prompt_factory(
|
||||||
return gemini_text_image_pt(messages=messages)
|
return gemini_text_image_pt(messages=messages)
|
||||||
elif custom_llm_provider == "mistral":
|
elif custom_llm_provider == "mistral":
|
||||||
return mistral_api_pt(messages=messages)
|
return mistral_api_pt(messages=messages)
|
||||||
|
elif custom_llm_provider == "bedrock":
|
||||||
|
if "amazon.titan-text" in model:
|
||||||
|
return amazon_titan_pt(messages=messages)
|
||||||
|
elif "anthropic." in model:
|
||||||
|
if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
|
||||||
|
return claude_2_1_pt(messages=messages)
|
||||||
|
else:
|
||||||
|
return anthropic_pt(messages=messages)
|
||||||
try:
|
try:
|
||||||
if "meta-llama/llama-2" in model and "chat" in model:
|
if "meta-llama/llama-2" in model and "chat" in model:
|
||||||
return llama_2_chat_pt(messages=messages)
|
return llama_2_chat_pt(messages=messages)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue