fix(anthropic.py-+-bedrock.py): anthropic prompt format

This commit is contained in:
Krrish Dholakia 2023-10-20 10:56:07 -07:00
parent 220935c3cc
commit 4b48af7c3c
4 changed files with 60 additions and 40 deletions

View file

@ -6,6 +6,7 @@ import time
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse from litellm.utils import ModelResponse
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: " HUMAN_PROMPT = "\n\nHuman: "
@ -71,6 +72,7 @@ def completion(
model: str, model: str,
messages: list, messages: list,
api_base: str, api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
@ -81,25 +83,18 @@ def completion(
logger_fn=None, logger_fn=None,
): ):
headers = validate_environment(api_key) headers = validate_environment(api_key)
prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" if model in custom_prompt_dict:
for message in messages: # check if the model has a registered custom prompt
if "role" in message: model_prompt_details = custom_prompt_dict[model]
if message["role"] == "user": prompt = custom_prompt(
prompt += ( role_dict=model_prompt_details["roles"],
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}" initial_prompt_value=model_prompt_details["initial_prompt_value"],
) final_prompt_value=model_prompt_details["final_prompt_value"],
elif message["role"] == "system": messages=messages
prompt += ( )
f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>" else:
) prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
else:
prompt += (
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
)
else:
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
prompt += f"{AnthropicConstants.AI_PROMPT.value}"
## Load Config ## Load Config
config = litellm.AnthropicConfig.get_config() config = litellm.AnthropicConfig.get_config()
for k, v in config.items(): for k, v in config.items():

View file

@ -4,6 +4,7 @@ import time
from typing import Callable, Optional from typing import Callable, Optional
import litellm import litellm
from litellm.utils import ModelResponse, get_secret from litellm.utils import ModelResponse, get_secret
from .prompt_templates.factory import prompt_factory, custom_prompt
class BedrockError(Exception): class BedrockError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
@ -206,27 +207,20 @@ def init_bedrock_client(
return client return client
def convert_messages_to_prompt(messages, provider): def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
# handle anthropic prompts using anthropic constants # handle anthropic prompts using anthropic constants
if provider == "anthropic": if provider == "anthropic":
prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" if model in custom_prompt_dict:
for message in messages: # check if the model has a registered custom prompt
if "role" in message: model_prompt_details = custom_prompt_dict[model]
if message["role"] == "user": prompt = custom_prompt(
prompt += ( role_dict=model_prompt_details["roles"],
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}" initial_prompt_value=model_prompt_details["initial_prompt_value"],
) final_prompt_value=model_prompt_details["final_prompt_value"],
elif message["role"] == "system": messages=messages
prompt += ( )
f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>" else:
) prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
else:
prompt += (
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
)
else:
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
prompt += f"{AnthropicConstants.AI_PROMPT.value}"
else: else:
prompt = "" prompt = ""
for message in messages: for message in messages:
@ -256,6 +250,7 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
@ -282,7 +277,7 @@ def completion(
model = model model = model
provider = model.split(".")[0] provider = model.split(".")[0]
prompt = convert_messages_to_prompt(messages, provider) prompt = convert_messages_to_prompt(model, messages, provider, custom_prompt_dict)
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
stream = inference_params.pop("stream", False) stream = inference_params.pop("stream", False)
if provider == "anthropic": if provider == "anthropic":

View file

@ -1,3 +1,4 @@
from enum import Enum
import requests, traceback import requests, traceback
import json import json
from jinja2 import Template, exceptions, Environment, meta from jinja2 import Template, exceptions, Environment, meta
@ -201,6 +202,31 @@ def hf_chat_template(model: str, messages: list):
except: except:
raise Exception("Error rendering template") raise Exception("Error rendering template")
# Anthropic template
def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
prompt = ""
for idx, message in enumerate(messages): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
if message["role"] == "user":
prompt += (
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
)
elif message["role"] == "system":
prompt += (
f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
)
else:
prompt += (
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
)
if idx == 0 and message["role"] == "assistant": # ensure the prompt always starts with `\n\nHuman: `
prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt
prompt += f"{AnthropicConstants.AI_PROMPT.value}"
return prompt
# Function call template # Function call template
def function_call_prompt(messages: list, functions: list): def function_call_prompt(messages: list, functions: list):
function_prompt = "The following functions are available to you:" function_prompt = "The following functions are available to you:"
@ -249,6 +275,8 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
if custom_llm_provider == "ollama": if custom_llm_provider == "ollama":
return ollama_pt(messages=messages) return ollama_pt(messages=messages)
elif custom_llm_provider == "anthropic":
return anthropic_pt(messages=messages)
try: try:
if "meta-llama/llama-2" in model: if "meta-llama/llama-2" in model:

View file

@ -586,7 +586,7 @@ def completion(
return response return response
response = model_response response = model_response
elif model in litellm.anthropic_models: elif custom_llm_provider=="anthropic":
anthropic_key = ( anthropic_key = (
api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY") or litellm.api_key api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY") or litellm.api_key
) )
@ -600,6 +600,7 @@ def completion(
model=model, model=model,
messages=messages, messages=messages,
api_base=api_base, api_base=api_base,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,
@ -1036,6 +1037,7 @@ def completion(
model_response = bedrock.completion( model_response = bedrock.completion(
model=model, model=model,
messages=messages, messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,