From cb5a13ed49d33e19a6ed45570491508abfa42cc0 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 13 Feb 2024 22:02:15 -0800
Subject: [PATCH] fix(bedrock.py): fix amazon titan prompt formatting

---
 litellm/llms/bedrock.py                  |  7 ++--
 litellm/llms/prompt_templates/factory.py | 44 ++++++++++++++++++++++--
 2 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index ae076fdf0..b7f1c5023 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -477,8 +477,8 @@ def init_bedrock_client(
 
 
 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
-    # handle anthropic prompts using anthropic constants
-    if provider == "anthropic":
+    # handle anthropic prompts and amazon titan prompts
+    if provider == "anthropic" or provider == "amazon":
         if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
             model_prompt_details = custom_prompt_dict[model]
@@ -490,7 +490,7 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
             )
         else:
             prompt = prompt_factory(
-                model=model, messages=messages, custom_llm_provider="anthropic"
+                model=model, messages=messages, custom_llm_provider="bedrock"
             )
     else:
         prompt = ""
@@ -623,6 +623,7 @@ def completion(
                     "textGenerationConfig": inference_params,
                 }
             )
+
         else:
             data = json.dumps({})
 
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 6321860cc..7896d7c96 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -90,9 +90,11 @@ def ollama_pt(
         return {"prompt": prompt, "images": images}
     else:
         prompt = "".join(
-            m["content"]
-            if isinstance(m["content"], str) is str
-            else "".join(m["content"])
+            (
+                m["content"]
+                if isinstance(m["content"], str)
+                else "".join(m["content"])
+            )
             for m in messages
         )
         return prompt
@@ -422,6 +424,34 @@ def anthropic_pt(
     return prompt
 
 
+def amazon_titan_pt(
+    messages: list,
+):  # format - https://github.com/BerriAI/litellm/issues/1896
+    """
+    Amazon Titan uses 'User:' and 'Bot:' in its prompt template.
+    """
+
+    class AmazonTitanConstants(Enum):
+        HUMAN_PROMPT = "\n\nUser: "  # assumed to mirror Anthropic's prompt formatting, since Amazon Titan's prompt format is currently undocumented
+        AI_PROMPT = "\n\nBot: "
+
+    prompt = ""
+    for idx, message in enumerate(messages):
+        if message["role"] == "user":
+            prompt += f"{AmazonTitanConstants.HUMAN_PROMPT.value}{message['content']}"
+        elif message["role"] == "system":
+            prompt += f"{AmazonTitanConstants.HUMAN_PROMPT.value}{message['content']}"
+        else:
+            prompt += f"{AmazonTitanConstants.AI_PROMPT.value}{message['content']}"
+        if (
+            idx == 0 and message["role"] == "assistant"
+        ):  # ensure the prompt always starts with `\n\nUser: `
+            prompt = f"{AmazonTitanConstants.HUMAN_PROMPT.value}" + prompt
+    if messages[-1]["role"] != "assistant":
+        prompt += f"{AmazonTitanConstants.AI_PROMPT.value}"
+    return prompt
+
+
 def _load_image_from_url(image_url):
     try:
         from PIL import Image
@@ -636,6 +666,14 @@ def prompt_factory(
         return gemini_text_image_pt(messages=messages)
     elif custom_llm_provider == "mistral":
         return mistral_api_pt(messages=messages)
+    elif custom_llm_provider == "bedrock":
+        if "amazon.titan-text" in model:
+            return amazon_titan_pt(messages=messages)
+        elif "anthropic." in model:
+            if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
+                return claude_2_1_pt(messages=messages)
+            else:
+                return anthropic_pt(messages=messages)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
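
---

Reviewer note, not part of the patch: a minimal sketch of what the new
bedrock branch in prompt_factory returns for a Titan model, useful for
sanity-checking the format against issue #1896. The model ID below is only
illustrative; any ID containing "amazon.titan-text" takes this code path.

    from litellm.llms.prompt_templates.factory import prompt_factory

    messages = [
        {"role": "system", "content": "You are a helpful bot."},
        {"role": "user", "content": "Hi"},
    ]

    # Hits the new `elif custom_llm_provider == "bedrock"` branch, which
    # dispatches to amazon_titan_pt for "amazon.titan-text" model IDs.
    prompt = prompt_factory(
        model="amazon.titan-text-express-v1",  # illustrative Bedrock model ID
        messages=messages,
        custom_llm_provider="bedrock",
    )
    print(repr(prompt))
    # Expected: '\n\nUser: You are a helpful bot.\n\nUser: Hi\n\nBot: '
    # (system and user messages both map to "User: "; a trailing "Bot: " is
    # appended because the last message is not from the assistant)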