forked from phoenix/litellm-mirror
fix(anthropic.py-+-bedrock.py): anthropic prompt format
This commit is contained in:
parent 220935c3cc
commit 4b48af7c3c
4 changed files with 60 additions and 40 deletions
litellm/llms/anthropic.py

@@ -6,6 +6,7 @@ import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse
 import litellm
+from .prompt_templates.factory import prompt_factory, custom_prompt
 
 class AnthropicConstants(Enum):
     HUMAN_PROMPT = "\n\nHuman: "
@@ -71,6 +72,7 @@ def completion(
     model: str,
     messages: list,
     api_base: str,
+    custom_prompt_dict: dict,
     model_response: ModelResponse,
     print_verbose: Callable,
     encoding,
@@ -81,24 +83,17 @@ def completion(
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
-    prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}"
-    for message in messages:
-        if "role" in message:
-            if message["role"] == "user":
-                prompt += (
-                    f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-                )
-            elif message["role"] == "system":
-                prompt += (
-                    f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
-                )
-            else:
-                prompt += (
-                    f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
-                )
-        else:
-            prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-    prompt += f"{AnthropicConstants.AI_PROMPT.value}"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details["roles"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
+            messages=messages
+        )
+    else:
+        prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
 
     ## Load Config
     config = litellm.AnthropicConfig.get_config()
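The net effect in anthropic.py: the handler no longer hand-builds the Human/Assistant string; it either applies a template registered for the model or defers to prompt_factory. A minimal sketch of what a custom_prompt_dict entry would need to look like for the first branch, assuming the pre_message/post_message keys that custom_prompt reads (those keys are not shown in this diff):

# hypothetical registration, for illustration only
custom_prompt_dict = {
    "claude-instant-1": {
        "roles": {
            "system": {"pre_message": "\n\nHuman: <admin>", "post_message": "</admin>"},
            "user": {"pre_message": "\n\nHuman: ", "post_message": ""},
            "assistant": {"pre_message": "\n\nAssistant: ", "post_message": ""},
        },
        "initial_prompt_value": "",
        "final_prompt_value": "\n\nAssistant: ",
    }
}
# models not present in the dict fall through to:
# prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")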
litellm/llms/bedrock.py

@@ -4,6 +4,7 @@ import time
 from typing import Callable, Optional
 import litellm
 from litellm.utils import ModelResponse, get_secret
+from .prompt_templates.factory import prompt_factory, custom_prompt
 
 class BedrockError(Exception):
     def __init__(self, status_code, message):
@@ -206,27 +207,20 @@ def init_bedrock_client(
     return client
 
 
-def convert_messages_to_prompt(messages, provider):
+def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     # handle anthropic prompts using anthropic constants
     if provider == "anthropic":
-        prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}"
-        for message in messages:
-            if "role" in message:
-                if message["role"] == "user":
-                    prompt += (
-                        f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-                    )
-                elif message["role"] == "system":
-                    prompt += (
-                        f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
-                    )
-                else:
-                    prompt += (
-                        f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
-                    )
-            else:
-                prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-        prompt += f"{AnthropicConstants.AI_PROMPT.value}"
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=messages
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
     else:
         prompt = ""
         for message in messages:
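Callers of convert_messages_to_prompt now pass the model id and the custom prompt registry as well. A quick sketch of the new call with made-up message content; an empty registry falls through to prompt_factory:

model_id = "anthropic.claude-v2"
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
]
prompt = convert_messages_to_prompt(
    model=model_id,
    messages=messages,
    provider=model_id.split(".")[0],   # "anthropic"
    custom_prompt_dict={},             # nothing registered -> prompt_factory path
)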
@@ -256,6 +250,7 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
 def completion(
     model: str,
     messages: list,
+    custom_prompt_dict: dict,
     model_response: ModelResponse,
     print_verbose: Callable,
     encoding,
@@ -282,7 +277,7 @@ def completion(
 
     model = model
     provider = model.split(".")[0]
-    prompt = convert_messages_to_prompt(messages, provider)
+    prompt = convert_messages_to_prompt(model, messages, provider, custom_prompt_dict)
     inference_params = copy.deepcopy(optional_params)
     stream = inference_params.pop("stream", False)
     if provider == "anthropic":
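For context, the provider in the Bedrock handler is just the first dot-separated segment of the model id, so the anthropic branch only fires for ids with that prefix (examples, not an exhaustive list):

for model_id in ["anthropic.claude-v2", "anthropic.claude-instant-v1", "ai21.j2-ultra"]:
    print(model_id, "->", model_id.split(".")[0])   # anthropic, anthropic, ai21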
litellm/llms/prompt_templates/factory.py

@@ -1,3 +1,4 @@
+from enum import Enum
 import requests, traceback
 import json
 from jinja2 import Template, exceptions, Environment, meta
@@ -201,6 +202,31 @@ def hf_chat_template(model: str, messages: list):
     except:
         raise Exception("Error rendering template")
 
+# Anthropic template
+def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
+    class AnthropicConstants(Enum):
+        HUMAN_PROMPT = "\n\nHuman: "
+        AI_PROMPT = "\n\nAssistant: "
+
+    prompt = ""
+    for idx, message in enumerate(messages): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
+        if message["role"] == "user":
+            prompt += (
+                f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
+            )
+        elif message["role"] == "system":
+            prompt += (
+                f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
+            )
+        else:
+            prompt += (
+                f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
+            )
+        if idx == 0 and message["role"] == "assistant": # ensure the prompt always starts with `\n\nHuman: `
+            prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt
+    prompt += f"{AnthropicConstants.AI_PROMPT.value}"
+    return prompt
+
 # Function call template
 def function_call_prompt(messages: list, functions: list):
     function_prompt = "The following functions are available to you:"
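Tracing the new anthropic_pt by hand over a small conversation shows the shape it guarantees: the prompt starts with "\n\nHuman: ", system content is wrapped in <admin> tags, and a trailing "\n\nAssistant: " leaves the model to continue as the assistant:

messages = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "What is 2+2?"},
]
# anthropic_pt(messages) should return:
# "\n\nHuman: <admin>Be brief.</admin>\n\nHuman: What is 2+2?\n\nAssistant: "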
@@ -249,6 +275,8 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
 
     if custom_llm_provider == "ollama":
         return ollama_pt(messages=messages)
+    elif custom_llm_provider == "anthropic":
+        return anthropic_pt(messages=messages)
 
     try:
         if "meta-llama/llama-2" in model:
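With that branch in place, the anthropic path of prompt_factory can be checked directly; a small sketch, assuming the factory module shown above is importable at this path:

from litellm.llms.prompt_templates.factory import prompt_factory

out = prompt_factory(
    model="claude-instant-1",   # ignored by the anthropic branch, which only uses messages
    messages=[{"role": "user", "content": "Hi"}],
    custom_llm_provider="anthropic",
)
assert out.startswith("\n\nHuman: ") and out.endswith("\n\nAssistant: ")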
litellm/main.py

@@ -586,7 +586,7 @@ def completion(
                 return response
             response = model_response
 
-        elif model in litellm.anthropic_models:
+        elif custom_llm_provider=="anthropic":
             anthropic_key = (
                 api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY") or litellm.api_key
             )
@@ -600,6 +600,7 @@ def completion(
                 model=model,
                 messages=messages,
                 api_base=api_base,
+                custom_prompt_dict=litellm.custom_prompt_dict,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=optional_params,
@@ -1036,6 +1037,7 @@ def completion(
            model_response = bedrock.completion(
                model=model,
                messages=messages,
+               custom_prompt_dict=litellm.custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
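End to end, the main.py change routes Claude calls by custom_llm_provider instead of checking litellm.anthropic_models, and threads litellm.custom_prompt_dict into both the anthropic and bedrock handlers. A hedged usage sketch (model name and key are illustrative placeholders):

import os
import litellm

os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."   # placeholder
response = litellm.completion(
    model="claude-instant-1",   # resolved to custom_llm_provider == "anthropic"
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(response["choices"][0]["message"]["content"])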