diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 3f151d1a9..1ca7e1710 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -222,6 +222,7 @@ class OpenAIChatCompletion(BaseLLM):
         custom_prompt_dict: dict = {},
         client=None,
         organization: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
     ):
         super().completion()
         exception_mapping_worked = False
@@ -236,6 +237,14 @@ class OpenAIChatCompletion(BaseLLM):
                     status_code=422, message=f"Timeout needs to be a float"
                 )
 
+            if custom_llm_provider == "mistral":
+                # check if message content passed in as list, and not string
+                messages = prompt_factory(
+                    model=model,
+                    messages=messages,
+                    custom_llm_provider=custom_llm_provider,
+                )
+
             for _ in range(
                 2
             ):  # if call fails due to alternating messages, retry with reformatted message
@@ -325,12 +334,13 @@ class OpenAIChatCompletion(BaseLLM):
                         model_response_object=model_response,
                     )
                 except Exception as e:
-                    if "Conversation roles must alternate user/assistant" in str(
-                        e
-                    ) or "user and assistant roles should be alternating" in str(e):
+                    if (
+                        "Conversation roles must alternate user/assistant" in str(e)
+                        or "user and assistant roles should be alternating" in str(e)
+                    ) and messages is not None:
                         # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' messages, add a blank 'user' or 'assistant' message to ensure compatibility
                         new_messages = []
-                        for i in range(len(messages) - 1):
+                        for i in range(len(messages) - 1):  # type: ignore
                             new_messages.append(messages[i])
                             if messages[i]["role"] == messages[i + 1]["role"]:
                                 if messages[i]["role"] == "user":
@@ -341,7 +351,9 @@
                                     new_messages.append({"role": "user", "content": ""})
                         new_messages.append(messages[-1])
                         messages = new_messages
-                    elif "Last message must have role `user`" in str(e):
+                    elif (
+                        "Last message must have role `user`" in str(e)
+                    ) and messages is not None:
                         new_messages = messages
                         new_messages.append({"role": "user", "content": ""})
                         messages = new_messages
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 1aebcf35d..6321860cc 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -116,6 +116,31 @@ def mistral_instruct_pt(messages):
     return prompt
 
 
+def mistral_api_pt(messages):
+    """
+    - handles scenario where content is list and not string
+    - content list is just text, and no images
+    - if image passed in, then just return as is (user-intended)
+
+    Motivation: mistral api doesn't support content as a list
+    """
+    new_messages = []
+    for m in messages:
+        texts = ""
+        if isinstance(m["content"], list):
+            for c in m["content"]:
+                if c["type"] == "image_url":
+                    return messages
+                elif c["type"] == "text" and isinstance(c["text"], str):
+                    texts += c["text"]
+        elif isinstance(m["content"], str):
+            # content is already a plain string - pass it through unchanged
+            texts = m["content"]
+        new_m = {"role": m["role"], "content": texts}
+        new_messages.append(new_m)
+    return new_messages
+
+
 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
 def falcon_instruct_pt(messages):
     prompt = ""
@@ -612,6 +637,8 @@ def prompt_factory(
             return _gemini_vision_convert_messages(messages=messages)
         else:
             return gemini_text_image_pt(messages=messages)
+    elif custom_llm_provider == "mistral":
+        return mistral_api_pt(messages=messages)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
diff --git a/litellm/main.py b/litellm/main.py
index 384dadc32..1a3f4cc3e 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -846,6 +846,7 @@ def completion(
                 custom_prompt_dict=custom_prompt_dict,
                 client=client,  # pass AsyncOpenAI, OpenAI client
                 organization=organization,
+                custom_llm_provider=custom_llm_provider,
             )
         except Exception as e:
             ## LOGGING - log the original exception returned
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 50fd1e3da..ebe85aa70 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -110,6 +110,29 @@ def test_completion_mistral_api():
 
 
 # test_completion_mistral_api()
+def test_completion_mistral_api_modified_input():
+    try:
+        litellm.set_verbose = True
+        response = completion(
+            model="mistral/mistral-tiny",
+            max_tokens=5,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [{"type": "text", "text": "Hey, how's it going?"}],
+                }
+            ],
+        )
+        # Add any assertions here to check the response
+        print(response)
+
+        cost = litellm.completion_cost(completion_response=response)
+        print("cost to make mistral completion=", cost)
+        assert cost > 0.0
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_claude2_1():
     try:
         print("claude2.1 test request")
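
Illustration (not part of the diff): a minimal standalone sketch of the transformation the new mistral_api_pt helper applies before a mistral request is sent. The function body mirrors the factory.py hunk above so it runs without litellm installed; the sample messages and the example image URL are made up.

# Standalone sketch of the mistral_api_pt transformation added above.
# Mirrors the factory.py hunk; sample inputs below are hypothetical.

def mistral_api_pt(messages):
    new_messages = []
    for m in messages:
        texts = ""
        if isinstance(m["content"], list):
            for c in m["content"]:
                if c["type"] == "image_url":
                    # image parts are user-intended: return the input unchanged
                    return messages
                elif c["type"] == "text" and isinstance(c["text"], str):
                    texts += c["text"]
        elif isinstance(m["content"], str):
            texts = m["content"]
        new_messages.append({"role": m["role"], "content": texts})
    return new_messages


# List-of-text content is flattened to the plain string the Mistral API expects:
flattened = mistral_api_pt(
    [{"role": "user", "content": [{"type": "text", "text": "Hey, how's it going?"}]}]
)
assert flattened == [{"role": "user", "content": "Hey, how's it going?"}]

# A message carrying an image_url part is returned untouched:
with_image = [
    {
        "role": "user",
        "content": [{"type": "image_url", "image_url": "https://example.com/cat.png"}],
    }
]
assert mistral_api_pt(with_image) is with_image

Returning the original list untouched whenever an image_url part appears keeps multimodal payloads intact rather than silently dropping images during flattening.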