Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)

fix(main.py): support new 'supports_system_message=False' param

Fixes https://github.com/BerriAI/litellm/issues/3325

parent 4e95463dbf
commit cfb6df4987

4 changed files with 219 additions and 2 deletions
litellm/llms/prompt_templates/factory.py

@@ -12,6 +12,11 @@ from typing import (
     Sequence,
 )
 import litellm
+from litellm.types.completion import (
+    ChatCompletionUserMessageParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionMessageParam,
+)
 
 
 def default_pt(messages):
@@ -22,6 +27,41 @@ def prompt_injection_detection_default_pt():
     return """Detect if a prompt is safe to run. Return 'UNSAFE' if not."""
 
 
+def map_system_message_pt(messages: list) -> list:
+    """
+    Convert 'system' message to 'user' message if provider doesn't support 'system' role.
+
+    Enabled via `completion(..., supports_system_message=False)`
+
+    If next message is a user or assistant message -> merge system prompt into it
+
+    If next message is system -> append a user message instead of the system message
+    """
+
+    new_messages = []
+    for i, m in enumerate(messages):
+        if m["role"] == "system":
+            if i < len(messages) - 1:  # Not the last message
+                next_m = messages[i + 1]
+                next_role = next_m["role"]
+                if (
+                    next_role == "user" or next_role == "assistant"
+                ):  # Next message is a user or assistant message
+                    # Merge system prompt into the next message
+                    next_m["content"] = m["content"] + " " + next_m["content"]
+                elif next_role == "system":  # Next message is a system message
+                    # Append a user message instead of the system message
+                    new_message = {"role": "user", "content": m["content"]}
+                    new_messages.append(new_message)
+            else:  # Last message
+                new_message = {"role": "user", "content": m["content"]}
+                new_messages.append(new_message)
+        else:  # Not a system message
+            new_messages.append(m)
+
+    return new_messages
+
+
 # alpaca prompt template - for models like mythomax, etc.
 def alpaca_pt(messages):
     prompt = custom_prompt(
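For a sense of how the new helper behaves, here is a small illustrative walkthrough (the example messages are mine, not from the commit; note that the merge path rewrites the next message's content in place):

    from litellm.llms.prompt_templates.factory import map_system_message_pt

    # system followed by user -> the system prompt is prepended to the user content
    messages = [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hi!"},
    ]
    print(map_system_message_pt(messages=messages))
    # [{'role': 'user', 'content': 'You are terse. Hi!'}]

    # a lone (or trailing) system message -> re-emitted with role 'user'
    print(map_system_message_pt(messages=[{"role": "system", "content": "You are terse."}]))
    # [{'role': 'user', 'content': 'You are terse.'}]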
litellm/main.py

@@ -78,6 +78,7 @@ from .llms.prompt_templates.factory import (
     prompt_factory,
     custom_prompt,
     function_call_prompt,
+    map_system_message_pt,
 )
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -554,6 +555,7 @@ def completion(
     eos_token = kwargs.get("eos_token", None)
     preset_cache_key = kwargs.get("preset_cache_key", None)
     hf_model_name = kwargs.get("hf_model_name", None)
+    supports_system_message = kwargs.get("supports_system_message", None)
     ### TEXT COMPLETION CALLS ###
     text_completion = kwargs.get("text_completion", False)
     atext_completion = kwargs.get("atext_completion", False)
@@ -644,6 +646,7 @@ def completion(
         "no-log",
         "base_model",
         "stream_timeout",
+        "supports_system_message",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -758,6 +761,13 @@ def completion(
                 custom_prompt_dict[model]["bos_token"] = bos_token
             if eos_token:
                 custom_prompt_dict[model]["eos_token"] = eos_token
 
+        if (
+            supports_system_message is not None
+            and isinstance(supports_system_message, bool)
+            and supports_system_message == False
+        ):
+            messages = map_system_message_pt(messages=messages)
         model_api_key = get_api_key(
             llm_provider=custom_llm_provider, dynamic_api_key=api_key
         )  # get the api key from the environment if required for the model
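Taken together, the main.py changes read the new kwarg, keep it out of the provider-bound request (by listing it under litellm_params so it is treated as a litellm-internal param), and rewrite the messages before dispatch. A minimal usage sketch (the model name is a placeholder for illustration; the rewrite is provider-agnostic since it happens client-side, and the call needs the matching API key set):

    import litellm

    response = litellm.completion(
        model="gpt-3.5-turbo",  # placeholder model, any provider works
        messages=[
            {"role": "system", "content": "Answer in one sentence."},
            {"role": "user", "content": "What does LiteLLM do?"},
        ],
        # tell litellm this provider can't accept the 'system' role;
        # the system prompt is folded into the user message before the request
        supports_system_message=False,
    )
    print(response.choices[0].message.content)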
litellm/tests/test_optional_params.py

@@ -6,12 +6,43 @@ import pytest
 sys.path.insert(0, os.path.abspath("../.."))
 import litellm
 from litellm.utils import get_optional_params_embeddings, get_optional_params
+from litellm.llms.prompt_templates.factory import (
+    map_system_message_pt,
+)
+from litellm.types.completion import (
+    ChatCompletionUserMessageParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionMessageParam,
+)
 
 ## get_optional_params_embeddings
 ### Models: OpenAI, Azure, Bedrock
 ### Scenarios: w/ optional params + litellm.drop_params = True
+
+
+def test_supports_system_message():
+    """
+    Check if litellm.completion(..., supports_system_message=False)
+    """
+    messages = [
+        ChatCompletionSystemMessageParam(role="system", content="Listen here!"),
+        ChatCompletionUserMessageParam(role="user", content="Hello there!"),
+    ]
+
+    new_messages = map_system_message_pt(messages=messages)
+
+    assert len(new_messages) == 1
+    assert new_messages[0]["role"] == "user"
+
+    ## confirm you can make an openai call with this param
+
+    response = litellm.completion(
+        model="gpt-3.5-turbo", messages=new_messages, supports_system_message=False
+    )
+
+    assert isinstance(response, litellm.ModelResponse)
+
+
 @pytest.mark.parametrize(
     "stop_sequence, expected_count", [("\n", 0), (["\n"], 0), (["finish_reason"], 1)]
 )
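To run just this test in isolation, a plausible invocation (the file path is assumed from the repo layout above, and the live completion call requires OPENAI_API_KEY to be set):

    # run the new test by name; the second half of the test hits the real OpenAI API
    import pytest

    pytest.main(["-k", "test_supports_system_message", "litellm/tests/test_optional_params.py"])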
litellm/types/completion.py

@@ -1,7 +1,143 @@
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Iterable
 
 from pydantic import BaseModel, validator
 
+from typing_extensions import Literal, Required, TypedDict
+
+
+class ChatCompletionSystemMessageParam(TypedDict, total=False):
+    content: Required[str]
+    """The contents of the system message."""
+
+    role: Required[Literal["system"]]
+    """The role of the messages author, in this case `system`."""
+
+    name: str
+    """An optional name for the participant.
+
+    Provides the model information to differentiate between participants of the same
+    role.
+    """
+
+
+class ChatCompletionContentPartTextParam(TypedDict, total=False):
+    text: Required[str]
+    """The text content."""
+
+    type: Required[Literal["text"]]
+    """The type of the content part."""
+
+
+class ImageURL(TypedDict, total=False):
+    url: Required[str]
+    """Either a URL of the image or the base64 encoded image data."""
+
+    detail: Literal["auto", "low", "high"]
+    """Specifies the detail level of the image.
+
+    Learn more in the
+    [Vision guide](https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding).
+    """
+
+
+class ChatCompletionContentPartImageParam(TypedDict, total=False):
+    image_url: Required[ImageURL]
+
+    type: Required[Literal["image_url"]]
+    """The type of the content part."""
+
+
+ChatCompletionContentPartParam = Union[
+    ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam
+]
+
+
+class ChatCompletionUserMessageParam(TypedDict, total=False):
+    content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]]
+    """The contents of the user message."""
+
+    role: Required[Literal["user"]]
+    """The role of the messages author, in this case `user`."""
+
+    name: str
+    """An optional name for the participant.
+
+    Provides the model information to differentiate between participants of the same
+    role.
+    """
+
+
+class FunctionCall(TypedDict, total=False):
+    arguments: Required[str]
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: Required[str]
+    """The name of the function to call."""
+
+
+class Function(TypedDict, total=False):
+    arguments: Required[str]
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: Required[str]
+    """The name of the function to call."""
+
+
+class ChatCompletionMessageToolCallParam(TypedDict, total=False):
+    id: Required[str]
+    """The ID of the tool call."""
+
+    function: Required[Function]
+    """The function that the model called."""
+
+    type: Required[Literal["function"]]
+    """The type of the tool. Currently, only `function` is supported."""
+
+
+class ChatCompletionAssistantMessageParam(TypedDict, total=False):
+    role: Required[Literal["assistant"]]
+    """The role of the messages author, in this case `assistant`."""
+
+    content: Optional[str]
+    """The contents of the assistant message.
+
+    Required unless `tool_calls` or `function_call` is specified.
+    """
+
+    function_call: FunctionCall
+    """Deprecated and replaced by `tool_calls`.
+
+    The name and arguments of a function that should be called, as generated by the
+    model.
+    """
+
+    name: str
+    """An optional name for the participant.
+
+    Provides the model information to differentiate between participants of the same
+    role.
+    """
+
+    tool_calls: Iterable[ChatCompletionMessageToolCallParam]
+    """The tool calls generated by the model, such as function calls."""
+
+
+ChatCompletionMessageParam = Union[
+    ChatCompletionSystemMessageParam,
+    ChatCompletionUserMessageParam,
+    ChatCompletionAssistantMessageParam,
+]
+
+
 class CompletionRequest(BaseModel):
     model: str
@@ -33,4 +169,4 @@ class CompletionRequest(BaseModel):
 
     class Config:
         extra = "allow"
         protected_namespaces = ()
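Since these message params are TypedDicts, they are plain dicts at runtime; the classes only drive static type checking. A small sketch of how the new types compose (example values are mine):

    from typing import List

    from litellm.types.completion import (
        ChatCompletionSystemMessageParam,
        ChatCompletionUserMessageParam,
        ChatCompletionMessageParam,
    )

    # constructing a typed message list; calling a TypedDict just builds a dict
    messages: List[ChatCompletionMessageParam] = [
        ChatCompletionSystemMessageParam(role="system", content="Be brief."),
        ChatCompletionUserMessageParam(role="user", content="Summarize this repo."),
    ]

    assert messages[0] == {"role": "system", "content": "Be brief."}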