fix(anthropic.py): support anthropic system prompt

This commit is contained in:
Krrish Dholakia 2024-03-04 10:11:29 -08:00
parent bd37b38ece
commit 1c40282627
2 changed files with 55 additions and 35 deletions

View file

@ -41,6 +41,7 @@ class AnthropicConfig:
top_p: Optional[int] = None
top_k: Optional[int] = None
metadata: Optional[dict] = None
system: Optional[str] = None
def __init__(
self,
@ -50,6 +51,7 @@ class AnthropicConfig:
top_p: Optional[int] = None,
top_k: Optional[int] = None,
metadata: Optional[dict] = None,
system: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@ -118,38 +120,19 @@ def completion(
messages=messages,
)
else:
prompt = prompt_factory(
# Separate system prompt from rest of message
system_prompt_idx: Optional[int] = None
for idx, message in enumerate(messages):
if message["role"] == "system":
optional_params["system"] = message["content"]
system_prompt_idx = idx
break
if system_prompt_idx is not None:
messages.pop(system_prompt_idx)
# Format rest of message according to anthropic guidelines
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
"""
format messages for anthropic
1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
2. The first message always needs to be of role "user"
3. Each message must alternate between "user" and "assistant" (this is not addressed as of now by litellm)
4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
"""
# 1. Anthropic only supports roles like "user" and "assistant"
for idx, message in enumerate(messages):
if message["role"] == "system":
message["role"] = "assistant"
# if this is the final assistant message, remove trailing whitespace
# TODO: only do this if it's the final assistant message
if message["role"] == "assistant":
message["content"] = message["content"].strip()
# 2. The first message always needs to be of role "user"
if len(messages) > 0:
if messages[0]["role"] != "user":
# find the index of the first user message
for i, message in enumerate(messages):
if message["role"] == "user":
break
# remove the user message at existing position and add it to the front
messages.pop(i)
# move the first user message to the front
messages = [message] + messages
## Load Config
config = litellm.AnthropicConfig.get_config()
@ -167,7 +150,7 @@ def completion(
## LOGGING
logging_obj.pre_call(
input=prompt,
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,

View file

@ -424,6 +424,46 @@ def anthropic_pt(
return prompt
def anthropic_messages_pt(messages: list):
    """
    Format a list of chat messages for the Anthropic Messages API.

    Rules enforced here:
    1. Anthropic supports only the "user" and "assistant" roles
       (system prompts are a separate `system` param, stripped upstream)
    2. The first message must be of role "user"
    3. Messages must strictly alternate between "user" and "assistant" —
       a blank message of the opposite role is inserted between any two
       consecutive same-role messages
    4. The final assistant message cannot end with trailing whitespace
       (Anthropic raises an error otherwise)

    Returns a new list; the message dicts themselves are shared with (and
    the stripped one mutated in) the caller's list.
    """
    new_messages: list = []
    # index INTO new_messages of the most recent assistant message, so the
    # strip below targets the right element even after blank insertions
    last_assistant_message_idx: Optional[int] = None

    for i, message in enumerate(messages):
        # 2. conversation must open with a user message
        if i == 0 and message["role"] == "assistant":
            new_messages.append({"role": "user", "content": ""})
        # 3. break up consecutive same-role messages with a blank message
        # of the opposite role
        if new_messages and new_messages[-1]["role"] == message["role"]:
            if message["role"] == "user":
                new_messages.append({"role": "assistant", "content": ""})
            else:
                new_messages.append({"role": "user", "content": ""})
        new_messages.append(message)
        if message["role"] == "assistant":
            last_assistant_message_idx = len(new_messages) - 1

    # 4. no trailing whitespace on the final assistant message
    if last_assistant_message_idx is not None:
        new_messages[last_assistant_message_idx]["content"] = new_messages[
            last_assistant_message_idx
        ]["content"].strip()
    return new_messages
def amazon_titan_pt(
messages: list,
): # format - https://github.com/BerriAI/litellm/issues/1896
@ -650,10 +690,7 @@ def prompt_factory(
if custom_llm_provider == "ollama":
return ollama_pt(model=model, messages=messages)
elif custom_llm_provider == "anthropic":
if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
return claude_2_1_pt(messages=messages)
else:
return anthropic_pt(messages=messages)
return anthropic_messages_pt(messages=messages)
elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model)
return format_prompt_togetherai(