fix(factory.py): mistral message input fix

This commit is contained in:
Krrish Dholakia 2024-02-08 20:54:26 -08:00
parent ff93609453
commit c9e5c796ad
4 changed files with 65 additions and 5 deletions

View file

@@ -222,6 +222,7 @@ class OpenAIChatCompletion(BaseLLM):
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
):
super().completion()
exception_mapping_worked = False
@@ -236,6 +237,14 @@ class OpenAIChatCompletion(BaseLLM):
status_code=422, message=f"Timeout needs to be a float"
)
if custom_llm_provider == "mistral":
# check if message content passed in as list, and not string
messages = prompt_factory(
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
)
for _ in range(
2
): # if call fails due to alternating messages, retry with reformatted message
@@ -325,12 +334,13 @@ class OpenAIChatCompletion(BaseLLM):
model_response_object=model_response,
)
except Exception as e:
if "Conversation roles must alternate user/assistant" in str(
e
) or "user and assistant roles should be alternating" in str(e):
if (
"Conversation roles must alternate user/assistant" in str(e)
or "user and assistant roles should be alternating" in str(e)
) and messages is not None:
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
new_messages = []
for i in range(len(messages) - 1):
for i in range(len(messages) - 1): # type: ignore
new_messages.append(messages[i])
if messages[i]["role"] == messages[i + 1]["role"]:
if messages[i]["role"] == "user":
@@ -341,7 +351,9 @@ class OpenAIChatCompletion(BaseLLM):
new_messages.append({"role": "user", "content": ""})
new_messages.append(messages[-1])
messages = new_messages
elif "Last message must have role `user`" in str(e):
elif (
"Last message must have role `user`" in str(e)
) and messages is not None:
new_messages = messages
new_messages.append({"role": "user", "content": ""})
messages = new_messages