fix(factory.py): mistral message input fix

Krrish Dholakia 2024-02-08 20:54:26 -08:00
parent ff93609453
commit c9e5c796ad
4 changed files with 65 additions and 5 deletions

litellm/llms/openai.py

@@ -222,6 +222,7 @@ class OpenAIChatCompletion(BaseLLM):
         custom_prompt_dict: dict = {},
         client=None,
         organization: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
     ):
         super().completion()
         exception_mapping_worked = False
@@ -236,6 +237,14 @@ class OpenAIChatCompletion(BaseLLM):
                     status_code=422, message=f"Timeout needs to be a float"
                 )

+            if custom_llm_provider == "mistral":
+                # check if message content passed in as list, and not string
+                messages = prompt_factory(
+                    model=model,
+                    messages=messages,
+                    custom_llm_provider=custom_llm_provider,
+                )
+
             for _ in range(
                 2
             ):  # if call fails due to alternating messages, retry with reformatted message
@@ -325,12 +334,13 @@ class OpenAIChatCompletion(BaseLLM):
                                 model_response_object=model_response,
                             )
                 except Exception as e:
-                    if "Conversation roles must alternate user/assistant" in str(
-                        e
-                    ) or "user and assistant roles should be alternating" in str(e):
+                    if (
+                        "Conversation roles must alternate user/assistant" in str(e)
+                        or "user and assistant roles should be alternating" in str(e)
+                    ) and messages is not None:
                         # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
                         new_messages = []
-                        for i in range(len(messages) - 1):
+                        for i in range(len(messages) - 1):  # type: ignore
                             new_messages.append(messages[i])
                             if messages[i]["role"] == messages[i + 1]["role"]:
                                 if messages[i]["role"] == "user":
@@ -341,7 +351,9 @@ class OpenAIChatCompletion(BaseLLM):
                                     new_messages.append({"role": "user", "content": ""})
                         new_messages.append(messages[-1])
                         messages = new_messages
-                    elif "Last message must have role `user`" in str(e):
+                    elif (
+                        "Last message must have role `user`" in str(e)
+                    ) and messages is not None:
                         new_messages = messages
                         new_messages.append({"role": "user", "content": ""})
                         messages = new_messages
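
For reference, the alternation fix in the retry path above can be sketched as a standalone function (fix_alternation is a hypothetical name used for illustration; it is not part of litellm):

def fix_alternation(messages):
    # mirror the retry logic above: between any two consecutive messages that
    # share a role, insert a blank message with the opposite role
    new_messages = []
    for i in range(len(messages) - 1):
        new_messages.append(messages[i])
        if messages[i]["role"] == messages[i + 1]["role"]:
            filler = "assistant" if messages[i]["role"] == "user" else "user"
            new_messages.append({"role": filler, "content": ""})
    new_messages.append(messages[-1])
    return new_messages

print(fix_alternation([
    {"role": "user", "content": "hi"},
    {"role": "user", "content": "hi again"},
]))
# -> [{'role': 'user', 'content': 'hi'},
#     {'role': 'assistant', 'content': ''},
#     {'role': 'user', 'content': 'hi again'}]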

litellm/llms/prompt_templates/factory.py

@@ -116,6 +116,28 @@ def mistral_instruct_pt(messages):
     return prompt


+def mistral_api_pt(messages):
+    """
+    - handles scenario where content is list and not string
+    - content list is just text, and no images
+    - if image passed in, then just return as is (user-intended)
+
+    Motivation: mistral api doesn't support content as a list
+    """
+    new_messages = []
+    for m in messages:
+        texts = ""
+        if isinstance(m["content"], list):
+            for c in m["content"]:
+                if c["type"] == "image_url":
+                    return messages
+                elif c["type"] == "text" and isinstance(c["text"], str):
+                    texts += c["text"]
+        new_m = {"role": m["role"], "content": texts}
+        new_messages.append(new_m)
+    return new_messages
+
+
 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
 def falcon_instruct_pt(messages):
     prompt = ""
@@ -612,6 +634,8 @@ def prompt_factory(
             return _gemini_vision_convert_messages(messages=messages)
         else:
             return gemini_text_image_pt(messages=messages)
+    elif custom_llm_provider == "mistral":
+        return mistral_api_pt(messages=messages)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
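
Net effect of the factory changes: prompt_factory now routes mistral requests through mistral_api_pt, which flattens text-only content lists into a plain string and returns the messages untouched as soon as an image_url part is seen. A quick sketch of the behavior (assuming mistral_api_pt is imported from this factory module):

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Hey, "},
            {"type": "text", "text": "how's it going?"},
        ],
    }
]
print(mistral_api_pt(messages))
# text parts are concatenated in order:
# -> [{'role': 'user', 'content': "Hey, how's it going?"}]

with_image = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
        ],
    }
]
# an image_url part short-circuits: the original list object is returned as-is
assert mistral_api_pt(with_image) is with_image

One caveat, assuming the indentation shown in the diff: a message whose content is already a plain string is re-emitted with content set to "", since texts is only populated from list-type content.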

litellm/main.py

@@ -846,6 +846,7 @@ def completion(
                 custom_prompt_dict=custom_prompt_dict,
                 client=client,  # pass AsyncOpenAI, OpenAI client
                 organization=organization,
+                custom_llm_provider=custom_llm_provider,
             )
         except Exception as e:
             ## LOGGING - log the original exception returned

litellm/tests/test_completion.py

@@ -110,6 +110,29 @@ def test_completion_mistral_api():
 # test_completion_mistral_api()


+def test_completion_mistral_api_modified_input():
+    try:
+        litellm.set_verbose = True
+        response = completion(
+            model="mistral/mistral-tiny",
+            max_tokens=5,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [{"type": "text", "text": "Hey, how's it going?"}],
+                }
+            ],
+        )
+        # Add any assertions here to check the response
+        print(response)
+
+        cost = litellm.completion_cost(completion_response=response)
+        print("cost to make mistral completion=", cost)
+        assert cost > 0.0
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_claude2_1():
     try:
         print("claude2.1 test request")