fix(main.py): support new 'supports_system_message=False' param

Fixes https://github.com/BerriAI/litellm/issues/3325
Krrish Dholakia 2024-05-03 21:31:45 -07:00
parent 4e95463dbf
commit cfb6df4987
4 changed files with 219 additions and 2 deletions
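
For context, a minimal sketch of how the new flag is meant to be used from the caller's side (the model name and message contents are placeholders, mirroring the test added below):

import litellm

# Ask litellm to fold the system prompt into the following user/assistant turn,
# for providers that reject the 'system' role.
response = litellm.completion(
    model="gpt-3.5-turbo",  # placeholder model
    messages=[
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Hello there!"},
    ],
    supports_system_message=False,
)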

View file

@@ -12,6 +12,11 @@ from typing import (
    Sequence,
)
import litellm
from litellm.types.completion import (
    ChatCompletionUserMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionMessageParam,
)


def default_pt(messages):
@@ -22,6 +27,41 @@ def prompt_injection_detection_default_pt():
    return """Detect if a prompt is safe to run. Return 'UNSAFE' if not."""


def map_system_message_pt(messages: list) -> list:
    """
    Convert 'system' messages to 'user' messages if the provider doesn't support the 'system' role.
    Enabled via `completion(..., supports_system_message=False)`

    If the next message is a user or assistant message -> merge the system prompt into it
    If the next message is another system message -> append a user message instead of the system message
    """
    new_messages = []
    for i, m in enumerate(messages):
        if m["role"] == "system":
            if i < len(messages) - 1:  # not the last message
                next_m = messages[i + 1]
                next_role = next_m["role"]
                if (
                    next_role == "user" or next_role == "assistant"
                ):  # next message is a user or assistant message
                    # merge the system prompt into the next message
                    next_m["content"] = m["content"] + " " + next_m["content"]
                elif next_role == "system":  # next message is a system message
                    # append a user message instead of the system message
                    new_message = {"role": "user", "content": m["content"]}
                    new_messages.append(new_message)
            else:  # last message
                new_message = {"role": "user", "content": m["content"]}
                new_messages.append(new_message)
        else:  # not a system message
            new_messages.append(m)
    return new_messages


# alpaca prompt template - for models like mythomax, etc.
def alpaca_pt(messages):
    prompt = custom_prompt(
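
To make the merge rule above concrete, a standalone snippet (message contents are made up; the expected output follows from the docstring's merge rule):

from litellm.llms.prompt_templates.factory import map_system_message_pt

example = [
    {"role": "system", "content": "Listen here!"},
    {"role": "user", "content": "Hello there!"},
]
print(map_system_message_pt(messages=example))
# expected: [{"role": "user", "content": "Listen here! Hello there!"}]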

View file

@@ -78,6 +78,7 @@ from .llms.prompt_templates.factory import (
    prompt_factory,
    custom_prompt,
    function_call_prompt,
    map_system_message_pt,
)
import tiktoken
from concurrent.futures import ThreadPoolExecutor
@@ -554,6 +555,7 @@ def completion(
    eos_token = kwargs.get("eos_token", None)
    preset_cache_key = kwargs.get("preset_cache_key", None)
    hf_model_name = kwargs.get("hf_model_name", None)
    supports_system_message = kwargs.get("supports_system_message", None)
    ### TEXT COMPLETION CALLS ###
    text_completion = kwargs.get("text_completion", False)
    atext_completion = kwargs.get("atext_completion", False)
@@ -644,6 +646,7 @@ def completion(
        "no-log",
        "base_model",
        "stream_timeout",
        "supports_system_message",
    ]
    default_params = openai_params + litellm_params
    non_default_params = {
@@ -758,6 +761,13 @@ def completion(
            custom_prompt_dict[model]["bos_token"] = bos_token
        if eos_token:
            custom_prompt_dict[model]["eos_token"] = eos_token

        if (
            supports_system_message is not None
            and isinstance(supports_system_message, bool)
            and supports_system_message == False
        ):
            messages = map_system_message_pt(messages=messages)
        model_api_key = get_api_key(
            llm_provider=custom_llm_provider, dynamic_api_key=api_key
        )  # get the api key from the environment if required for the model
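
The gate above only rewrites the conversation when the caller explicitly passes a boolean False; leaving the kwarg unset (None) or passing True keeps the messages untouched. A standalone sketch of the same decision (the helper name here is illustrative, not part of litellm):

from litellm.llms.prompt_templates.factory import map_system_message_pt

def _maybe_remap_system_messages(messages: list, supports_system_message=None) -> list:
    # mirrors the guard in completion(): only an explicit boolean False triggers the remap
    if supports_system_message is False:
        return map_system_message_pt(messages=messages)
    return messages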

View file

@@ -6,12 +6,43 @@ import pytest
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm.utils import get_optional_params_embeddings, get_optional_params
from litellm.llms.prompt_templates.factory import (
    map_system_message_pt,
)
from litellm.types.completion import (
    ChatCompletionUserMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionMessageParam,
)

## get_optional_params_embeddings
### Models: OpenAI, Azure, Bedrock
### Scenarios: w/ optional params + litellm.drop_params = True


def test_supports_system_message():
    """
    Check that `litellm.completion(..., supports_system_message=False)` works.
    """
    messages = [
        ChatCompletionSystemMessageParam(role="system", content="Listen here!"),
        ChatCompletionUserMessageParam(role="user", content="Hello there!"),
    ]

    new_messages = map_system_message_pt(messages=messages)

    assert len(new_messages) == 1
    assert new_messages[0]["role"] == "user"

    ## confirm you can make an openai call with this param
    response = litellm.completion(
        model="gpt-3.5-turbo", messages=new_messages, supports_system_message=False
    )

    assert isinstance(response, litellm.ModelResponse)


@pytest.mark.parametrize(
    "stop_sequence, expected_count", [("\n", 0), (["\n"], 0), (["finish_reason"], 1)]
)

View file

@@ -1,7 +1,143 @@
from typing import List, Optional, Union, Iterable
from pydantic import BaseModel, validator
from typing_extensions import Literal, Required, TypedDict


class ChatCompletionSystemMessageParam(TypedDict, total=False):
    content: Required[str]
    """The contents of the system message."""

    role: Required[Literal["system"]]
    """The role of the messages author, in this case `system`."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the same
    role.
    """


class ChatCompletionContentPartTextParam(TypedDict, total=False):
    text: Required[str]
    """The text content."""

    type: Required[Literal["text"]]
    """The type of the content part."""


class ImageURL(TypedDict, total=False):
    url: Required[str]
    """Either a URL of the image or the base64 encoded image data."""

    detail: Literal["auto", "low", "high"]
    """Specifies the detail level of the image.

    Learn more in the
    [Vision guide](https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding).
    """


class ChatCompletionContentPartImageParam(TypedDict, total=False):
    image_url: Required[ImageURL]

    type: Required[Literal["image_url"]]
    """The type of the content part."""


ChatCompletionContentPartParam = Union[
    ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam
]


class ChatCompletionUserMessageParam(TypedDict, total=False):
    content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]]
    """The contents of the user message."""

    role: Required[Literal["user"]]
    """The role of the messages author, in this case `user`."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the same
    role.
    """


class FunctionCall(TypedDict, total=False):
    arguments: Required[str]
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: Required[str]
    """The name of the function to call."""


class Function(TypedDict, total=False):
    arguments: Required[str]
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: Required[str]
    """The name of the function to call."""


class ChatCompletionMessageToolCallParam(TypedDict, total=False):
    id: Required[str]
    """The ID of the tool call."""

    function: Required[Function]
    """The function that the model called."""

    type: Required[Literal["function"]]
    """The type of the tool. Currently, only `function` is supported."""


class ChatCompletionAssistantMessageParam(TypedDict, total=False):
    role: Required[Literal["assistant"]]
    """The role of the messages author, in this case `assistant`."""

    content: Optional[str]
    """The contents of the assistant message.

    Required unless `tool_calls` or `function_call` is specified.
    """

    function_call: FunctionCall
    """Deprecated and replaced by `tool_calls`.

    The name and arguments of a function that should be called, as generated by the
    model.
    """

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the same
    role.
    """

    tool_calls: Iterable[ChatCompletionMessageToolCallParam]
    """The tool calls generated by the model, such as function calls."""


ChatCompletionMessageParam = Union[
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
    ChatCompletionAssistantMessageParam,
]


class CompletionRequest(BaseModel):
    model: str
@@ -33,4 +169,4 @@ class CompletionRequest(BaseModel):
    class Config:
        extra = "allow"
        protected_namespaces = ()
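
Finally, a small sketch of how the new message TypedDicts can be combined with map_system_message_pt when building a conversation (the message contents are arbitrary):

from litellm.types.completion import (
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
)
from litellm.llms.prompt_templates.factory import map_system_message_pt

system_msg: ChatCompletionSystemMessageParam = {"role": "system", "content": "Be brief."}
user_msg: ChatCompletionUserMessageParam = {"role": "user", "content": "Summarize the release notes."}

# the system prompt gets merged into the user turn, leaving a single user message
print(map_system_message_pt(messages=[system_msg, user_msg]))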