diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 3f151d1a9..1ca7e1710 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -222,6 +222,7 @@ class OpenAIChatCompletion(BaseLLM):
         custom_prompt_dict: dict = {},
         client=None,
         organization: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
     ):
         super().completion()
         exception_mapping_worked = False
@@ -236,6 +237,14 @@ class OpenAIChatCompletion(BaseLLM):
                     status_code=422, message=f"Timeout needs to be a float"
                 )
 
+            if custom_llm_provider == "mistral":
+                # check if message content passed in as list, and not string
+                messages = prompt_factory(
+                    model=model,
+                    messages=messages,
+                    custom_llm_provider=custom_llm_provider,
+                )
+
             for _ in range(
                 2
             ):  # if call fails due to alternating messages, retry with reformatted message
@@ -325,12 +334,13 @@ class OpenAIChatCompletion(BaseLLM):
                             model_response_object=model_response,
                         )
                 except Exception as e:
-                    if "Conversation roles must alternate user/assistant" in str(
-                        e
-                    ) or "user and assistant roles should be alternating" in str(e):
+                    if (
+                        "Conversation roles must alternate user/assistant" in str(e)
+                        or "user and assistant roles should be alternating" in str(e)
+                    ) and messages is not None:
                         # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
                         new_messages = []
-                        for i in range(len(messages) - 1):
+                        for i in range(len(messages) - 1):  # type: ignore
                             new_messages.append(messages[i])
                             if messages[i]["role"] == messages[i + 1]["role"]:
                                 if messages[i]["role"] == "user":
@@ -341,7 +351,9 @@
                                     new_messages.append({"role": "user", "content": ""})
                         new_messages.append(messages[-1])
                         messages = new_messages
-                    elif "Last message must have role `user`" in str(e):
+                    elif (
+                        "Last message must have role `user`" in str(e)
+                    ) and messages is not None:
                         new_messages = messages
                         new_messages.append({"role": "user", "content": ""})
                         messages = new_messages
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 1aebcf35d..6321860cc 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -116,6 +116,28 @@ def mistral_instruct_pt(messages):
     return prompt
 
 
+def mistral_api_pt(messages):
+    """
+    - handles scenario where content is list and not string
+    - content list is just text, and no images
+    - if image passed in, then just return as is (user-intended)
+
+    Motivation: mistral api doesn't support content as a list
+    """
+    new_messages = []
+    for m in messages:
+        texts = ""
+        if isinstance(m["content"], list):
+            for c in m["content"]:
+                if c["type"] == "image_url":
+                    return messages
+                elif c["type"] == "text" and isinstance(c["text"], str):
+                    texts += c["text"]
+        new_m = {"role": m["role"], "content": texts}
+        new_messages.append(new_m)
+    return new_messages
+
+
 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
 def falcon_instruct_pt(messages):
     prompt = ""
@@ -612,6 +634,8 @@ def prompt_factory(
             return _gemini_vision_convert_messages(messages=messages)
         else:
             return gemini_text_image_pt(messages=messages)
+    elif custom_llm_provider == "mistral":
+        return mistral_api_pt(messages=messages)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
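
For context, a minimal sketch (not part of the diff) of how the new mistral path behaves once these changes are applied. The example messages, model name, and import path layout are assumptions for illustration; the only functions relied on are prompt_factory and mistral_api_pt from the diff above.

# Illustrative sketch only -- not part of the diff.
# Shows list-style "content" being flattened to a plain string for the Mistral API
# via the new prompt_factory -> mistral_api_pt branch. Example values are assumed.
from litellm.llms.prompt_templates.factory import prompt_factory

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this deployment "},
            {"type": "text", "text": "in one sentence."},
        ],
    }
]

# With custom_llm_provider="mistral", prompt_factory returns mistral_api_pt(messages),
# which concatenates the text parts of each message into a single string.
flattened = prompt_factory(
    model="mistral-tiny",  # assumed model id; the mistral branch ignores the model value
    messages=messages,
    custom_llm_provider="mistral",
)
print(flattened)
# [{'role': 'user', 'content': 'Describe this deployment in one sentence.'}]

# If any content part has type "image_url", mistral_api_pt returns the original
# messages unchanged (user-intended passthrough).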