forked from phoenix/litellm-mirror

fix(anthropic.py): support anthropic system prompt

parent: bd37b38ece
commit: 1c40282627

2 changed files with 55 additions and 35 deletions
litellm/llms/anthropic.py

@@ -41,6 +41,7 @@ class AnthropicConfig:
     top_p: Optional[int] = None
     top_k: Optional[int] = None
     metadata: Optional[dict] = None
+    system: Optional[str] = None
 
     def __init__(
         self,
@@ -50,6 +51,7 @@ class AnthropicConfig:
         top_p: Optional[int] = None,
         top_k: Optional[int] = None,
         metadata: Optional[dict] = None,
+        system: Optional[str] = None,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
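The new `system` field rides on litellm's shared config pattern: `__init__` walks `locals()` and stores every non-None argument as a class-level default, and `get_config()` (read later in `completion`, see the `## Load Config` block below) returns those defaults as a dict. A minimal sketch of that flow, assuming litellm's usual setattr-on-class behavior for these config objects and a made-up prompt string:

import litellm

# Store a provider-wide default system prompt; __init__ copies every
# non-None local (here just `system`) onto the AnthropicConfig class.
litellm.AnthropicConfig(system="You are a terse assistant.")  # hypothetical value

# completion() later merges these defaults into the Anthropic request payload.
config = litellm.AnthropicConfig.get_config()
print(config.get("system"))  # "You are a terse assistant."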
@@ -118,38 +120,19 @@ def completion(
             messages=messages,
         )
     else:
-        prompt = prompt_factory(
+        # Separate system prompt from rest of message
+        system_prompt_idx: Optional[int] = None
+        for idx, message in enumerate(messages):
+            if message["role"] == "system":
+                optional_params["system"] = message["content"]
+                system_prompt_idx = idx
+                break
+        if system_prompt_idx is not None:
+            messages.pop(system_prompt_idx)
+        # Format rest of message according to anthropic guidelines
+        messages = prompt_factory(
             model=model, messages=messages, custom_llm_provider="anthropic"
         )
-        """
-        format messages for anthropic
-        1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
-        2. The first message always needs to be of role "user"
-        3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
-        4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
-        """
-        # 1. Anthropic only supports roles like "user" and "assistant"
-        for idx, message in enumerate(messages):
-            if message["role"] == "system":
-                message["role"] = "assistant"
-
-            # if this is the final assistant message, remove trailing whitespace
-            # TODO: only do this if it's the final assistant message
-            if message["role"] == "assistant":
-                message["content"] = message["content"].strip()
-
-        # 2. The first message always needs to be of role "user"
-        if len(messages) > 0:
-            if messages[0]["role"] != "user":
-                # find the index of the first user message
-                for i, message in enumerate(messages):
-                    if message["role"] == "user":
-                        break
-
-                # remove the user message at existing position and add it to the front
-                messages.pop(i)
-                # move the first user message to the front
-                messages = [message] + messages
 
     ## Load Config
     config = litellm.AnthropicConfig.get_config()
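A standalone sketch of the new separation step, with a hypothetical message list; `optional_params` stands in for the kwargs dict that `completion` folds into the request body:

from typing import Optional

optional_params: dict = {}
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
]

# Only the first system message is promoted to the dedicated `system` param;
# the loop breaks after one hit, so any later system messages stay in place.
system_prompt_idx: Optional[int] = None
for idx, message in enumerate(messages):
    if message["role"] == "system":
        optional_params["system"] = message["content"]
        system_prompt_idx = idx
        break
if system_prompt_idx is not None:
    messages.pop(system_prompt_idx)

print(optional_params)  # {'system': 'You are a helpful assistant.'}
print(messages)         # [{'role': 'user', 'content': 'Hi'}]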
@@ -167,7 +150,7 @@ def completion(
 
     ## LOGGING
     logging_obj.pre_call(
-        input=prompt,
+        input=messages,
        api_key=api_key,
         additional_args={
             "complete_input_dict": data,
litellm/llms/prompt_templates/factory.py

@@ -424,6 +424,46 @@ def anthropic_pt(
     return prompt
 
 
+def anthropic_messages_pt(messages: list):
+    """
+    format messages for anthropic
+    1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
+    2. The first message always needs to be of role "user"
+    3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
+    4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
+    5. System messages are a separate param to the Messages API (used for tool calling)
+    """
+    ## Ensure final assistant message has no trailing whitespace
+    last_assistant_message_idx: Optional[int] = None
+    # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
+    new_messages = []
+    for i in range(len(messages) - 1):  # type: ignore
+        if i == 0 and messages[i]["role"] == "assistant":
+            new_messages.append({"role": "user", "content": ""})
+
+        new_messages.append(messages[i])
+
+        if messages[i]["role"] == messages[i + 1]["role"]:
+            if messages[i]["role"] == "user":
+                new_messages.append({"role": "assistant", "content": ""})
+            else:
+                new_messages.append({"role": "user", "content": ""})
+
+        if messages[i]["role"] == "assistant":
+            last_assistant_message_idx = i
+
+    new_messages.append(messages[-1])
+
+    if last_assistant_message_idx is not None:
+        new_messages[last_assistant_message_idx]["content"] = new_messages[
+            last_assistant_message_idx
+        ][
+            "content"
+        ].strip()  # no trailing whitespace for final assistant message
+
+    return new_messages
+
+
 def amazon_titan_pt(
     messages: list,
 ):  # format - https://github.com/BerriAI/litellm/issues/1896
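Two behavior sketches for anthropic_messages_pt with hypothetical inputs (import path assumed, not shown in the diff). Consecutive same-role messages get a blank opposite-role turn spliced between them, and an assistant turn followed by another message is right-stripped; note that a conversation ending on an assistant turn is appended after the loop, so its trailing whitespace survives:

from litellm.llms.prompt_templates.factory import anthropic_messages_pt  # assumed path

# Two consecutive "user" turns -> a blank assistant turn is inserted between them.
print(anthropic_messages_pt([
    {"role": "user", "content": "First question"},
    {"role": "user", "content": "Second question"},
]))
# [{'role': 'user', 'content': 'First question'},
#  {'role': 'assistant', 'content': ''},
#  {'role': 'user', 'content': 'Second question'}]

# Trailing whitespace on the last assistant turn before the final message is stripped.
print(anthropic_messages_pt([
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! "},
    {"role": "user", "content": "How are you?"},
]))
# [{'role': 'user', 'content': 'Hi'},
#  {'role': 'assistant', 'content': 'Hello!'},
#  {'role': 'user', 'content': 'How are you?'}]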
@@ -650,10 +690,7 @@ def prompt_factory(
     if custom_llm_provider == "ollama":
         return ollama_pt(model=model, messages=messages)
     elif custom_llm_provider == "anthropic":
-        if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
-            return claude_2_1_pt(messages=messages)
-        else:
-            return anthropic_pt(messages=messages)
+        return anthropic_messages_pt(messages=messages)
     elif custom_llm_provider == "together_ai":
         prompt_format, chat_template = get_model_info(token=api_key, model=model)
         return format_prompt_togetherai(
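With this routing change the claude-2.1/claude-v2:1 special case disappears, and every Anthropic model now comes back from prompt_factory as a message list rather than a flattened prompt string, which is why the logging call above switched from prompt to messages. A hedged usage sketch; the model name and import path are assumptions:

from litellm.llms.prompt_templates.factory import prompt_factory  # assumed path

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "user", "content": "Still there?"},
]

# Any Anthropic model now dispatches to anthropic_messages_pt.
formatted = prompt_factory(
    model="claude-2.1",  # hypothetical model; no longer routed specially
    messages=messages,
    custom_llm_provider="anthropic",
)
print(formatted)  # alternating list: user, blank assistant, user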