forked from phoenix/litellm-mirror

fix(anthropic.py-+-bedrock.py): anthropic prompt format

parent 220935c3cc
commit 4b48af7c3c

4 changed files with 60 additions and 40 deletions
anthropic.py

@@ -6,6 +6,7 @@ import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse
 import litellm
+from .prompt_templates.factory import prompt_factory, custom_prompt

 class AnthropicConstants(Enum):
     HUMAN_PROMPT = "\n\nHuman: "
@@ -71,6 +72,7 @@ def completion(
     model: str,
     messages: list,
     api_base: str,
+    custom_prompt_dict: dict,
     model_response: ModelResponse,
     print_verbose: Callable,
     encoding,
@@ -81,25 +83,18 @@ def completion(
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
-    prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}"
-    for message in messages:
-        if "role" in message:
-            if message["role"] == "user":
-                prompt += (
-                    f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-                )
-            elif message["role"] == "system":
-                prompt += (
-                    f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
-                )
-            else:
-                prompt += (
-                    f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
-                )
-        else:
-            prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-    prompt += f"{AnthropicConstants.AI_PROMPT.value}"
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details["roles"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
+            messages=messages
+        )
+    else:
+        prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")

     ## Load Config
     config = litellm.AnthropicConfig.get_config()
     for k, v in config.items():
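For reference, a minimal sketch of the entry shape this new branch looks up in custom_prompt_dict. The top-level keys ("roles", "initial_prompt_value", "final_prompt_value") come straight from the diff; the model name and the inner layout of "roles" are illustrative assumptions, since custom_prompt itself is not shown in this commit.

# Hypothetical custom_prompt_dict entry consumed by the branch above.
# "roles" / "initial_prompt_value" / "final_prompt_value" are read by the new code;
# the pre_message/post_message layout inside "roles" is assumed for illustration.
custom_prompt_dict = {
    "claude-instant-1": {
        "roles": {
            "system": {"pre_message": "\n\nHuman: <admin>", "post_message": "</admin>"},
            "user": {"pre_message": "\n\nHuman: ", "post_message": ""},
            "assistant": {"pre_message": "\n\nAssistant: ", "post_message": ""},
        },
        "initial_prompt_value": "",
        "final_prompt_value": "\n\nAssistant: ",
    }
}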
bedrock.py

@@ -4,6 +4,7 @@ import time
 from typing import Callable, Optional
 import litellm
 from litellm.utils import ModelResponse, get_secret
+from .prompt_templates.factory import prompt_factory, custom_prompt

 class BedrockError(Exception):
     def __init__(self, status_code, message):
@@ -206,27 +207,20 @@ def init_bedrock_client(
     return client


-def convert_messages_to_prompt(messages, provider):
+def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     # handle anthropic prompts using anthropic constants
     if provider == "anthropic":
-        prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}"
-        for message in messages:
-            if "role" in message:
-                if message["role"] == "user":
-                    prompt += (
-                        f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-                    )
-                elif message["role"] == "system":
-                    prompt += (
-                        f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
-                    )
-                else:
-                    prompt += (
-                        f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
-                    )
-            else:
-                prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
-        prompt += f"{AnthropicConstants.AI_PROMPT.value}"
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=messages
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
     else:
         prompt = ""
         for message in messages:
@ -256,6 +250,7 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
|
||||||
def completion(
|
def completion(
|
||||||
model: str,
|
model: str,
|
||||||
messages: list,
|
messages: list,
|
||||||
|
custom_prompt_dict: dict,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
encoding,
|
encoding,
|
||||||
|
@@ -282,7 +277,7 @@ def completion(

     model = model
     provider = model.split(".")[0]
-    prompt = convert_messages_to_prompt(messages, provider)
+    prompt = convert_messages_to_prompt(model, messages, provider, custom_prompt_dict)
     inference_params = copy.deepcopy(optional_params)
     stream = inference_params.pop("stream", False)
     if provider == "anthropic":
prompt_templates/factory.py

@@ -1,3 +1,4 @@
+from enum import Enum
 import requests, traceback
 import json
 from jinja2 import Template, exceptions, Environment, meta
@@ -201,6 +202,31 @@ def hf_chat_template(model: str, messages: list):
     except:
         raise Exception("Error rendering template")

+# Anthropic template
+def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
+    class AnthropicConstants(Enum):
+        HUMAN_PROMPT = "\n\nHuman: "
+        AI_PROMPT = "\n\nAssistant: "
+
+    prompt = ""
+    for idx, message in enumerate(messages): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
+        if message["role"] == "user":
+            prompt += (
+                f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
+            )
+        elif message["role"] == "system":
+            prompt += (
+                f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
+            )
+        else:
+            prompt += (
+                f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
+            )
+        if idx == 0 and message["role"] == "assistant": # ensure the prompt always starts with `\n\nHuman: `
+            prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt
+    prompt += f"{AnthropicConstants.AI_PROMPT.value}"
+    return prompt
+
 # Function call template
 def function_call_prompt(messages: list, functions: list):
     function_prompt = "The following functions are available to you:"
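As a quick sanity check of the new template, here is what it produces for a simple system + user exchange. This is a sketch derived only from the anthropic_pt function added above; the message contents are illustrative.

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
prompt = anthropic_pt(messages=messages)
# prompt is now:
#   "\n\nHuman: <admin>You are a helpful assistant.</admin>"
#   "\n\nHuman: What is the capital of France?"
#   "\n\nAssistant: "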
@@ -249,6 +275,8 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str

     if custom_llm_provider == "ollama":
         return ollama_pt(messages=messages)
+    elif custom_llm_provider == "anthropic":
+        return anthropic_pt(messages=messages)

     try:
         if "meta-llama/llama-2" in model:
main.py

@@ -586,7 +586,7 @@ def completion(
                 return response
             response = model_response

-        elif model in litellm.anthropic_models:
+        elif custom_llm_provider=="anthropic":
             anthropic_key = (
                 api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY") or litellm.api_key
             )
@@ -600,6 +600,7 @@ def completion(
                 model=model,
                 messages=messages,
                 api_base=api_base,
+                custom_prompt_dict=litellm.custom_prompt_dict,
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=optional_params,
@@ -1036,6 +1037,7 @@ def completion(
         model_response = bedrock.completion(
             model=model,
             messages=messages,
+            custom_prompt_dict=litellm.custom_prompt_dict,
             model_response=model_response,
             print_verbose=print_verbose,
             optional_params=optional_params,
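End to end, a caller could take advantage of the plumbing added above roughly like this. This is a sketch: it assumes litellm.custom_prompt_dict is a plain module-level dict that can be populated directly, and the model name and "roles" layout are illustrative rather than taken from this diff.

# Hypothetical registration of a custom Claude prompt format, followed by a
# normal completion call. main.py now forwards litellm.custom_prompt_dict to the
# anthropic/bedrock handlers, which check it before falling back to prompt_factory.
import litellm

litellm.custom_prompt_dict["claude-instant-1"] = {
    "roles": {
        "user": {"pre_message": "\n\nHuman: ", "post_message": ""},
        "assistant": {"pre_message": "\n\nAssistant: ", "post_message": ""},
    },
    "initial_prompt_value": "",
    "final_prompt_value": "\n\nAssistant: ",
}

response = litellm.completion(
    model="claude-instant-1",
    messages=[{"role": "user", "content": "Hello"}],
)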