fix(bedrock.py): fix amazon titan prompt formatting
parent b81f8ec8ca
commit cb5a13ed49
2 changed files with 45 additions and 6 deletions
litellm/llms/bedrock.py

@@ -477,8 +477,8 @@ def init_bedrock_client(
 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
-    # handle anthropic prompts using anthropic constants
-    if provider == "anthropic":
+    # handle anthropic prompts and amazon titan prompts
+    if provider == "anthropic" or provider == "amazon":
         if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
             model_prompt_details = custom_prompt_dict[model]
@@ -490,7 +490,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
             )
         else:
             prompt = prompt_factory(
-                model=model, messages=messages, custom_llm_provider="anthropic"
+                model=model, messages=messages, custom_llm_provider="bedrock"
             )
     else:
         prompt = ""
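With provider "amazon" now sharing the "anthropic" branch, a Titan model with no registered custom prompt falls through to prompt_factory with the generic "bedrock" tag instead of the hard-coded "anthropic" one. A minimal sketch of the new call path (the model id and message list below are illustrative, not from the diff):

```python
# hypothetical invocation of the function patched above
prompt = convert_messages_to_prompt(
    model="amazon.titan-text-express-v1",          # assumed Titan model id
    messages=[{"role": "user", "content": "Hi"}],
    provider="amazon",                             # now routed like "anthropic"
    custom_prompt_dict={},                         # no custom template registered
)
# underneath, prompt_factory(..., custom_llm_provider="bedrock") now runs
```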
@@ -623,6 +623,7 @@ def completion(
                     "textGenerationConfig": inference_params,
                 }
             )
+
         else:
             data = json.dumps({})
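The branch above serializes the Amazon Titan invocation body. For reference, a sketch of what that `json.dumps` call produces, following the documented Bedrock Titan text format; the field values here are made up for illustration:

```python
import json

# illustrative Titan request body; shape is {"inputText": ..., "textGenerationConfig": ...}
data = json.dumps(
    {
        "inputText": "\n\nUser: Hi\n\nBot: ",
        "textGenerationConfig": {
            "maxTokenCount": 256,
            "temperature": 0.7,
            "topP": 0.9,
            "stopSequences": [],
        },
    }
)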
litellm/llms/prompt_templates/factory.py

@@ -90,9 +90,11 @@ def ollama_pt(
         return {"prompt": prompt, "images": images}
     else:
         prompt = "".join(
-            m["content"]
-            if isinstance(m["content"], str) is str
-            else "".join(m["content"])
+            (
+                m["content"]
+                if isinstance(m["content"], str) is str
+                else "".join(m["content"])
+            )
             for m in messages
         )
     return prompt
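Note on the hunk above: the commit only wraps the conditional expression in parentheses (a formatting change). The pre-existing `isinstance(m["content"], str) is str` test is left intact even though it always evaluates False, because `isinstance` returns a bool, never the type `str`. A corrected sketch of what the condition presumably means (not part of this commit):

```python
# drop the spurious "is str"; take the string as-is, join list content otherwise
prompt = "".join(
    m["content"] if isinstance(m["content"], str) else "".join(m["content"])
    for m in messages
)
```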
@@ -422,6 +424,34 @@ def anthropic_pt(
     return prompt


+def amazon_titan_pt(
+    messages: list,
+):  # format - https://github.com/BerriAI/litellm/issues/1896
+    """
+    Amazon Titan uses 'User:' and 'Bot:' in its prompt template
+    """
+
+    class AmazonTitanConstants(Enum):
+        HUMAN_PROMPT = "\n\nUser: "  # Assuming this is similar to Anthropic prompt formatting, since amazon titan's prompt formatting is currently undocumented
+        AI_PROMPT = "\n\nBot: "
+
+    prompt = ""
+    for idx, message in enumerate(messages):
+        if message["role"] == "user":
+            prompt += f"{AmazonTitanConstants.HUMAN_PROMPT.value}{message['content']}"
+        elif message["role"] == "system":
+            prompt += f"{AmazonTitanConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
+        else:
+            prompt += f"{AmazonTitanConstants.AI_PROMPT.value}{message['content']}"
+        if (
+            idx == 0 and message["role"] == "assistant"
+        ):  # ensure the prompt always starts with `\n\nUser: `
+            prompt = f"{AmazonTitanConstants.HUMAN_PROMPT.value}" + prompt
+    if messages[-1]["role"] != "assistant":
+        prompt += f"{AmazonTitanConstants.AI_PROMPT.value}"
+    return prompt
+
+
 def _load_image_from_url(image_url):
     try:
         from PIL import Image
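A worked example of the new formatter's output, traced from the code above (the conversation is made up; `Enum` is presumably already imported in this module, since the diff adds no import and the existing Anthropic constants use the same pattern):

```python
messages = [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "Hi"},
]
amazon_titan_pt(messages)
# system turns are wrapped in <admin> tags on the User side, and a trailing
# "Bot: " cue is appended because the last message is not from the assistant:
# "\n\nUser: <admin>Be terse.</admin>\n\nUser: Hi\n\nBot: "
```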
@@ -636,6 +666,14 @@ def prompt_factory(
         return gemini_text_image_pt(messages=messages)
     elif custom_llm_provider == "mistral":
         return mistral_api_pt(messages=messages)
+    elif custom_llm_provider == "bedrock":
+        if "amazon.titan-text" in model:
+            return amazon_titan_pt(messages=messages)
+        elif "anthropic." in model:
+            if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
+                return claude_2_1_pt(messages=messages)
+            else:
+                return anthropic_pt(messages=messages)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
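The new bedrock branch dispatches on the model id. A quick sketch of which formatter each id now reaches (the Bedrock model ids are illustrative choices, not taken from the diff; `msgs` is any OpenAI-style message list):

```python
prompt_factory(model="amazon.titan-text-express-v1", messages=msgs, custom_llm_provider="bedrock")
# -> amazon_titan_pt

prompt_factory(model="anthropic.claude-v2:1", messages=msgs, custom_llm_provider="bedrock")
# -> claude_2_1_pt  (matches "claude-v2:1")

prompt_factory(model="anthropic.claude-instant-v1", messages=msgs, custom_llm_provider="bedrock")
# -> anthropic_pt
```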