Merge pull request #1902 from BerriAI/litellm_mistral_message_list_fix

fix(factory.py): mistral message input fix
This commit is contained in:
Krish Dholakia 2024-02-08 23:01:39 -08:00 committed by GitHub
commit 51c07e294a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 41 additions and 5 deletions

View file

@ -222,6 +222,7 @@ class OpenAIChatCompletion(BaseLLM):
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
):
super().completion()
exception_mapping_worked = False
@ -236,6 +237,14 @@ class OpenAIChatCompletion(BaseLLM):
status_code=422, message=f"Timeout needs to be a float"
)
if custom_llm_provider == "mistral":
# check if message content passed in as list, and not string
messages = prompt_factory(
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
)
for _ in range(
2
): # if call fails due to alternating messages, retry with reformatted message
@ -325,12 +334,13 @@ class OpenAIChatCompletion(BaseLLM):
model_response_object=model_response,
)
except Exception as e:
if "Conversation roles must alternate user/assistant" in str(
e
) or "user and assistant roles should be alternating" in str(e):
if (
"Conversation roles must alternate user/assistant" in str(e)
or "user and assistant roles should be alternating" in str(e)
) and messages is not None:
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
new_messages = []
for i in range(len(messages) - 1):
for i in range(len(messages) - 1): # type: ignore
new_messages.append(messages[i])
if messages[i]["role"] == messages[i + 1]["role"]:
if messages[i]["role"] == "user":
@ -341,7 +351,9 @@ class OpenAIChatCompletion(BaseLLM):
new_messages.append({"role": "user", "content": ""})
new_messages.append(messages[-1])
messages = new_messages
elif "Last message must have role `user`" in str(e):
elif (
"Last message must have role `user`" in str(e)
) and messages is not None:
new_messages = messages
new_messages.append({"role": "user", "content": ""})
messages = new_messages

View file

@ -116,6 +116,28 @@ def mistral_instruct_pt(messages):
return prompt
def mistral_api_pt(messages):
    """
    Flatten list-style message content into plain strings for the Mistral API.

    - handles scenario where content is a list and not a string
    - a content list that is just text parts is concatenated into one string
    - content that is already a string is passed through unchanged
    - if an image is passed in, then just return the messages as is (user-intended)

    Motivation: mistral api doesn't support content as a list

    Args:
        messages: iterable of chat messages; each is indexed by "role" and
            "content" (dict messages may omit "content").

    Returns:
        The original ``messages`` object, untouched, as soon as any content
        part has type "image_url"; otherwise a new list of dicts whose
        "content" is always a string. The input messages are never mutated.
    """
    new_messages = []
    for m in messages:
        # Dict messages may legitimately lack "content" (e.g. tool-call turns).
        content = m.get("content") if isinstance(m, dict) else m["content"]

        if isinstance(content, list):
            text_parts = []
            for c in content:
                part_type = c.get("type")
                if part_type == "image_url":
                    # Image input is user-intended: hand everything back as is.
                    return messages
                if part_type == "text" and isinstance(c.get("text"), str):
                    text_parts.append(c["text"])
            content = "".join(text_parts)
        elif not isinstance(content, str):
            # None / missing / unsupported content is sent as an empty string.
            content = ""

        # Copy dict messages so extra fields (name, tool_calls, tool_call_id, ...)
        # survive; only "content" is rewritten.
        new_m = dict(m) if isinstance(m, dict) else {"role": m["role"]}
        new_m["content"] = content
        new_messages.append(new_m)
    return new_messages
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def falcon_instruct_pt(messages):
prompt = ""
@ -612,6 +634,8 @@ def prompt_factory(
return _gemini_vision_convert_messages(messages=messages)
else:
return gemini_text_image_pt(messages=messages)
elif custom_llm_provider == "mistral":
return mistral_api_pt(messages=messages)
try:
if "meta-llama/llama-2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages)