fix(anthropic.py): support anthropic system prompt
commit 1c40282627
parent bd37b38ece

2 changed files with 55 additions and 35 deletions
@@ -41,6 +41,7 @@ class AnthropicConfig:
     top_p: Optional[int] = None
     top_k: Optional[int] = None
     metadata: Optional[dict] = None
+    system: Optional[str] = None

     def __init__(
         self,
@@ -50,6 +51,7 @@ class AnthropicConfig:
         top_p: Optional[int] = None,
         top_k: Optional[int] = None,
         metadata: Optional[dict] = None,
+        system: Optional[str] = None,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
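The two hunks above add `system` as a first-class field on `AnthropicConfig`. Under litellm's usual config pattern (the `locals_` loop shown above promotes non-None constructor args to class-level defaults, and `get_config()` reads them back), a default system prompt could be set once for all Anthropic calls. A minimal sketch assuming that pattern; the prompt text is invented:

import litellm

# Hypothetical: set a provider-wide default system prompt. __init__
# stores non-None locals as class-level defaults via the locals_ loop
# shown in the diff above.
litellm.AnthropicConfig(system="You are a helpful assistant.")

# completion() later merges these defaults into the request params:
config = litellm.AnthropicConfig.get_config()
print(config.get("system"))  # expected: "You are a helpful assistant."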
@@ -118,38 +120,19 @@ def completion(
             messages=messages,
         )
     else:
-        prompt = prompt_factory(
+        # Separate system prompt from rest of message
+        system_prompt_idx: Optional[int] = None
+        for idx, message in enumerate(messages):
+            if message["role"] == "system":
+                optional_params["system"] = message["content"]
+                system_prompt_idx = idx
+                break
+        if system_prompt_idx is not None:
+            messages.pop(system_prompt_idx)
+        # Format rest of message according to anthropic guidelines
+        messages = prompt_factory(
             model=model, messages=messages, custom_llm_provider="anthropic"
         )
-        """
-        format messages for anthropic
-        1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
-        2. The first message always needs to be of role "user"
-        3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
-        4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
-        """
-        # 1. Anthropic only supports roles like "user" and "assistant"
-        for idx, message in enumerate(messages):
-            if message["role"] == "system":
-                message["role"] = "assistant"
-
-            # if this is the final assistant message, remove trailing whitespace
-            # TODO: only do this if it's the final assistant message
-            if message["role"] == "assistant":
-                message["content"] = message["content"].strip()
-
-        # 2. The first message always needs to be of role "user"
-        if len(messages) > 0:
-            if messages[0]["role"] != "user":
-                # find the index of the first user message
-                for i, message in enumerate(messages):
-                    if message["role"] == "user":
-                        break
-
-                # remove the user message at existing position and add it to the front
-                messages.pop(i)
-                # move the first user message to the front
-                messages = [message] + messages

     ## Load Config
     config = litellm.AnthropicConfig.get_config()
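The new branch above pulls the first system message out of `messages`, stashes it in `optional_params["system"]` (Anthropic's top-level `system` param), and formats the remaining turns via `prompt_factory`. A self-contained sketch of just the extraction step, with an invented message list:

from typing import Optional

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Hi!"},
]
optional_params: dict = {}

# Separate system prompt from rest of message (same logic as the diff)
system_prompt_idx: Optional[int] = None
for idx, message in enumerate(messages):
    if message["role"] == "system":
        optional_params["system"] = message["content"]
        system_prompt_idx = idx
        break
if system_prompt_idx is not None:
    messages.pop(system_prompt_idx)

assert optional_params["system"] == "You are a terse assistant."
assert messages == [{"role": "user", "content": "Hi!"}]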
@@ -167,7 +150,7 @@ def completion(

     ## LOGGING
     logging_obj.pre_call(
-        input=prompt,
+        input=messages,
         api_key=api_key,
         additional_args={
             "complete_input_dict": data,
@@ -424,6 +424,46 @@ def anthropic_pt(
     return prompt


+def anthropic_messages_pt(messages: list):
+    """
+    format messages for anthropic
+    1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
+    2. The first message always needs to be of role "user"
+    3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
+    4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
+    5. System messages are a separate param to the Messages API (used for tool calling)
+    """
+    ## Ensure final assistant message has no trailing whitespace
+    last_assistant_message_idx: Optional[int] = None
+    # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' messages, add a blank 'user' or 'assistant' message to ensure compatibility
+    new_messages = []
+    for i in range(len(messages) - 1):  # type: ignore
+        if i == 0 and messages[i]["role"] == "assistant":
+            new_messages.append({"role": "user", "content": ""})
+
+        new_messages.append(messages[i])
+
+        if messages[i]["role"] == messages[i + 1]["role"]:
+            if messages[i]["role"] == "user":
+                new_messages.append({"role": "assistant", "content": ""})
+            else:
+                new_messages.append({"role": "user", "content": ""})
+
+        if messages[i]["role"] == "assistant":
+            last_assistant_message_idx = i
+
+    new_messages.append(messages[-1])
+
+    if last_assistant_message_idx is not None:
+        new_messages[last_assistant_message_idx]["content"] = new_messages[
+            last_assistant_message_idx
+        ][
+            "content"
+        ].strip()  # no trailing whitespace for final assistant message
+
+    return new_messages
+
+
 def amazon_titan_pt(
     messages: list,
 ):  # format - https://github.com/BerriAI/litellm/issues/1896
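A usage sketch of `anthropic_messages_pt` as defined above (sample messages invented): consecutive same-role messages get a blank opposite-role message inserted between them, and the last assistant message seen in the loop is stripped of trailing whitespace.

# Assumes anthropic_messages_pt from the hunk above is in scope.
messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "Four. "},  # trailing space
    {"role": "user", "content": "Thanks"},
]
print(anthropic_messages_pt(messages=messages))
# [{'role': 'user', 'content': 'What is 2 + 2?'},
#  {'role': 'assistant', 'content': 'Four.'},   <- stripped
#  {'role': 'user', 'content': 'Thanks'}]

messages = [
    {"role": "user", "content": "First"},
    {"role": "user", "content": "Second"},
]
print(anthropic_messages_pt(messages=messages))
# [{'role': 'user', 'content': 'First'},
#  {'role': 'assistant', 'content': ''},        <- inserted to alternate
#  {'role': 'user', 'content': 'Second'}]

One caveat: `last_assistant_message_idx` indexes the original `messages` list but is used against `new_messages`, so blank messages inserted before the last assistant turn can shift which entry actually gets stripped.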
@@ -650,10 +690,7 @@ def prompt_factory(
     if custom_llm_provider == "ollama":
         return ollama_pt(model=model, messages=messages)
     elif custom_llm_provider == "anthropic":
-        if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
-            return claude_2_1_pt(messages=messages)
-        else:
-            return anthropic_pt(messages=messages)
+        return anthropic_messages_pt(messages=messages)
     elif custom_llm_provider == "together_ai":
         prompt_format, chat_template = get_model_info(token=api_key, model=model)
         return format_prompt_togetherai(
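Taken together, a sketch of the intended end-to-end behavior after this commit (model name illustrative): the system message is sent as Anthropic's top-level `system` param instead of being folded into the prompt text, and the remaining turns are normalized by `anthropic_messages_pt`.

import litellm

response = litellm.completion(
    model="claude-2.1",  # any Anthropic model routed through this path
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
)
# Standard litellm response shape (OpenAI-compatible):
print(response.choices[0].message.content)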