Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)

fix(main.py): support new 'supports_system_message=False' param

Fixes https://github.com/BerriAI/litellm/issues/3325

parent: 4e95463dbf
commit: cfb6df4987

4 changed files with 219 additions and 2 deletions
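In practice the new flag is just an extra kwarg on `litellm.completion`. A minimal sketch of the intended call, modeled on the test added in this commit (the model name and message contents here are illustrative):

```python
import litellm

# Providers that reject the 'system' role can be handled by letting
# litellm fold the system prompt into the neighboring user message.
response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative; any supported model works
    messages=[
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Hello there!"},
    ],
    supports_system_message=False,  # new param introduced by this commit
)
print(response.choices[0].message.content)
```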
litellm/llms/prompt_templates/factory.py

@@ -12,6 +12,11 @@ from typing import (

```python
    Sequence,
)
import litellm
from litellm.types.completion import (
    ChatCompletionUserMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionMessageParam,
)


def default_pt(messages):
```

@@ -22,6 +27,41 @@ def prompt_injection_detection_default_pt():

```python
    return """Detect if a prompt is safe to run. Return 'UNSAFE' if not."""


def map_system_message_pt(messages: list) -> list:
    """
    Convert 'system' message to 'user' message if provider doesn't support 'system' role.

    Enabled via `completion(...,supports_system_message=False)`

    If next message is a user message or assistant message -> merge system prompt into it

    if next message is system -> append a user message instead of the system message
    """
    new_messages = []
    for i, m in enumerate(messages):
        if m["role"] == "system":
            if i < len(messages) - 1:  # Not the last message
                next_m = messages[i + 1]
                next_role = next_m["role"]
                if (
                    next_role == "user" or next_role == "assistant"
                ):  # Next message is a user or assistant message
                    # Merge system prompt into the next message
                    next_m["content"] = m["content"] + " " + next_m["content"]
                elif next_role == "system":  # Next message is a system message
                    # Append a user message instead of the system message
                    new_message = {"role": "user", "content": m["content"]}
                    new_messages.append(new_message)
            else:  # Last message
                new_message = {"role": "user", "content": m["content"]}
                new_messages.append(new_message)
        else:  # Not a system message
            new_messages.append(m)

    return new_messages


# alpaca prompt template - for models like mythomax, etc.
def alpaca_pt(messages):
    prompt = custom_prompt(
```
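To make the mapping rules concrete, here is a small walkthrough of the three branches (message contents are made up; the import path matches the test added later in this commit):

```python
from litellm.llms.prompt_templates.factory import map_system_message_pt

# Case 1: system followed by user -> merged into the user message.
out = map_system_message_pt(
    messages=[
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "Hi!"},
    ]
)
assert out == [{"role": "user", "content": "Be brief. Hi!"}]

# Case 2: system followed by system -> the first is re-labelled as a user
# message; the second is then the last message and is also converted,
# so no 'system' role survives.
out = map_system_message_pt(
    messages=[
        {"role": "system", "content": "Rule A."},
        {"role": "system", "content": "Rule B."},
    ]
)
assert all(m["role"] == "user" for m in out)

# Case 3: a trailing system message becomes a user message.
out = map_system_message_pt(messages=[{"role": "system", "content": "Only rule."}])
assert out == [{"role": "user", "content": "Only rule."}]
```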
litellm/main.py

@@ -78,6 +78,7 @@ from .llms.prompt_templates.factory import (

```python
    prompt_factory,
    custom_prompt,
    function_call_prompt,
    map_system_message_pt,
)
import tiktoken
from concurrent.futures import ThreadPoolExecutor
```

@@ -554,6 +555,7 @@ def completion(

```python
    eos_token = kwargs.get("eos_token", None)
    preset_cache_key = kwargs.get("preset_cache_key", None)
    hf_model_name = kwargs.get("hf_model_name", None)
    supports_system_message = kwargs.get("supports_system_message", None)
    ### TEXT COMPLETION CALLS ###
    text_completion = kwargs.get("text_completion", False)
    atext_completion = kwargs.get("atext_completion", False)
```

@@ -644,6 +646,7 @@ def completion(

```python
        "no-log",
        "base_model",
        "stream_timeout",
        "supports_system_message",
    ]
    default_params = openai_params + litellm_params
    non_default_params = {
```

@@ -758,6 +761,13 @@ def completion(

```python
            custom_prompt_dict[model]["bos_token"] = bos_token
        if eos_token:
            custom_prompt_dict[model]["eos_token"] = eos_token

    if (
        supports_system_message is not None
        and isinstance(supports_system_message, bool)
        and supports_system_message == False
    ):
        messages = map_system_message_pt(messages=messages)
    model_api_key = get_api_key(
        llm_provider=custom_llm_provider, dynamic_api_key=api_key
    )  # get the api key from the environment if required for the model
```
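The third hunk above cuts off mid-statement at `non_default_params = {`. For readers unfamiliar with this part of `completion()`: registering the new kwarg in `litellm_params` keeps it out of the parameters forwarded to the provider SDK. A minimal, illustrative sketch of that separation, assuming the truncated comprehension filters on `default_params` (the lists here are abbreviated stand-ins, not litellm's exact code):

```python
# Anything in `default_params` is handled internally by litellm and is
# NOT forwarded to the underlying provider SDK.
openai_params = ["temperature", "max_tokens", "stream"]  # abbreviated
litellm_params = ["base_model", "stream_timeout", "supports_system_message"]  # abbreviated
default_params = openai_params + litellm_params

passed_params = {"top_k": 40, "supports_system_message": False}
non_default_params = {
    k: v for k, v in passed_params.items() if k not in default_params
}  # provider-specific extras, passed straight through

assert non_default_params == {"top_k": 40}  # the new flag stays internal
```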
litellm/tests/…

@@ -6,12 +6,43 @@ import pytest

```python
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm.utils import get_optional_params_embeddings, get_optional_params
from litellm.llms.prompt_templates.factory import (
    map_system_message_pt,
)
from litellm.types.completion import (
    ChatCompletionUserMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionMessageParam,
)

## get_optional_params_embeddings
### Models: OpenAI, Azure, Bedrock
### Scenarios: w/ optional params + litellm.drop_params = True


def test_supports_system_message():
    """
    Check if litellm.completion(...,supports_system_message=False)
    """
    messages = [
        ChatCompletionSystemMessageParam(role="system", content="Listen here!"),
        ChatCompletionUserMessageParam(role="user", content="Hello there!"),
    ]

    new_messages = map_system_message_pt(messages=messages)

    assert len(new_messages) == 1
    assert new_messages[0]["role"] == "user"

    ## confirm you can make an openai call with this param

    response = litellm.completion(
        model="gpt-3.5-turbo", messages=new_messages, supports_system_message=False
    )

    assert isinstance(response, litellm.ModelResponse)


@pytest.mark.parametrize(
    "stop_sequence, expected_count", [("\n", 0), (["\n"], 0), (["finish_reason"], 1)]
)
```
litellm/types/completion.py

@@ -1,7 +1,143 @@

```python
from typing import List, Optional, Union, Iterable  # `Iterable` added to the existing import

from pydantic import BaseModel, validator

from typing_extensions import Literal, Required, TypedDict


class ChatCompletionSystemMessageParam(TypedDict, total=False):
    content: Required[str]
    """The contents of the system message."""

    role: Required[Literal["system"]]
    """The role of the messages author, in this case `system`."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the same
    role.
    """


class ChatCompletionContentPartTextParam(TypedDict, total=False):
    text: Required[str]
    """The text content."""

    type: Required[Literal["text"]]
    """The type of the content part."""


class ImageURL(TypedDict, total=False):
    url: Required[str]
    """Either a URL of the image or the base64 encoded image data."""

    detail: Literal["auto", "low", "high"]
    """Specifies the detail level of the image.

    Learn more in the
    [Vision guide](https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding).
    """


class ChatCompletionContentPartImageParam(TypedDict, total=False):
    image_url: Required[ImageURL]

    type: Required[Literal["image_url"]]
    """The type of the content part."""


ChatCompletionContentPartParam = Union[
    ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam
]


class ChatCompletionUserMessageParam(TypedDict, total=False):
    content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]]
    """The contents of the user message."""

    role: Required[Literal["user"]]
    """The role of the messages author, in this case `user`."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the same
    role.
    """


class FunctionCall(TypedDict, total=False):
    arguments: Required[str]
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: Required[str]
    """The name of the function to call."""


class Function(TypedDict, total=False):
    arguments: Required[str]
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: Required[str]
    """The name of the function to call."""


class ChatCompletionMessageToolCallParam(TypedDict, total=False):
    id: Required[str]
    """The ID of the tool call."""

    function: Required[Function]
    """The function that the model called."""

    type: Required[Literal["function"]]
    """The type of the tool. Currently, only `function` is supported."""


class ChatCompletionAssistantMessageParam(TypedDict, total=False):
    role: Required[Literal["assistant"]]
    """The role of the messages author, in this case `assistant`."""

    content: Optional[str]
    """The contents of the assistant message.

    Required unless `tool_calls` or `function_call` is specified.
    """

    function_call: FunctionCall
    """Deprecated and replaced by `tool_calls`.

    The name and arguments of a function that should be called, as generated by the
    model.
    """

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the same
    role.
    """

    tool_calls: Iterable[ChatCompletionMessageToolCallParam]
    """The tool calls generated by the model, such as function calls."""


ChatCompletionMessageParam = Union[
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
    ChatCompletionAssistantMessageParam,
]


class CompletionRequest(BaseModel):
    model: str
```

@@ -33,4 +169,4 @@ class CompletionRequest(BaseModel):

```python
    class Config:
        extra = "allow"
        protected_namespaces = ()
```
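A small sketch (message contents are illustrative) of how these types are meant to be used. Note that a `TypedDict` adds static typing only, with no runtime validation or wrapper object, which is why `map_system_message_pt` can index the results as ordinary dicts:

```python
from litellm.types.completion import (
    ChatCompletionMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
)

# The class call is just a typed dict constructor.
messages: list[ChatCompletionMessageParam] = [
    ChatCompletionSystemMessageParam(role="system", content="Answer briefly."),
    ChatCompletionUserMessageParam(role="user", content="What is a TypedDict?"),
]

assert isinstance(messages[0], dict)        # plain dict at runtime
assert messages[0]["role"] == "system"      # indexable like any dict
```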