fix(main.py): support new 'supports_system_message=False' param

Fixes https://github.com/BerriAI/litellm/issues/3325
This commit is contained in:
Krrish Dholakia 2024-05-03 21:31:45 -07:00
parent 4e95463dbf
commit cfb6df4987
4 changed files with 219 additions and 2 deletions

View file

@ -78,6 +78,7 @@ from .llms.prompt_templates.factory import (
prompt_factory,
custom_prompt,
function_call_prompt,
map_system_message_pt,
)
import tiktoken
from concurrent.futures import ThreadPoolExecutor
@ -554,6 +555,7 @@ def completion(
eos_token = kwargs.get("eos_token", None)
preset_cache_key = kwargs.get("preset_cache_key", None)
hf_model_name = kwargs.get("hf_model_name", None)
supports_system_message = kwargs.get("supports_system_message", None)
### TEXT COMPLETION CALLS ###
text_completion = kwargs.get("text_completion", False)
atext_completion = kwargs.get("atext_completion", False)
@ -644,6 +646,7 @@ def completion(
"no-log",
"base_model",
"stream_timeout",
"supports_system_message",
]
default_params = openai_params + litellm_params
non_default_params = {
@ -758,6 +761,13 @@ def completion(
custom_prompt_dict[model]["bos_token"] = bos_token
if eos_token:
custom_prompt_dict[model]["eos_token"] = eos_token
if (
supports_system_message is not None
and isinstance(supports_system_message, bool)
and supports_system_message == False
):
messages = map_system_message_pt(messages=messages)
model_api_key = get_api_key(
llm_provider=custom_llm_provider, dynamic_api_key=api_key
) # get the api key from the environment if required for the model