From 9d20c19e0c6cb3c346cb875d086e035b7a867903 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 11 Nov 2024 18:03:41 -0800
Subject: [PATCH] (fix) OpenAI's optional messages[].name does not work with
 Mistral API (#6701)

* use helper for _transform_messages mistral

* add test_message_with_name to base LLMChat test

* fix linting
---
 .../mistral/mistral_chat_transformation.py    | 57 +++++++++++++++++++
 litellm/llms/prompt_templates/factory.py      | 39 +------------
 tests/llm_translation/base_llm_unit_tests.py  |  8 +++
 tests/llm_translation/test_mistral_api.py     | 34 +++++++++++
 4 files changed, 100 insertions(+), 38 deletions(-)
 create mode 100644 tests/llm_translation/test_mistral_api.py

diff --git a/litellm/llms/mistral/mistral_chat_transformation.py b/litellm/llms/mistral/mistral_chat_transformation.py
index 5d1a54c3a..aeb1a90fd 100644
--- a/litellm/llms/mistral/mistral_chat_transformation.py
+++ b/litellm/llms/mistral/mistral_chat_transformation.py
@@ -10,6 +10,7 @@ import types
 from typing import List, Literal, Optional, Tuple, Union
 
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
 
 
 class MistralConfig:
@@ -148,3 +149,59 @@ class MistralConfig:
             or get_secret_str("MISTRAL_API_KEY")
         )
         return api_base, dynamic_api_key
+
+    @classmethod
+    def _transform_messages(cls, messages: List[AllMessageValues]):
+        """
+        - handles scenario where content is list and not string
+        - content list is just text, and no images
+        - if image passed in, then just return as is (user-intended)
+        - if `name` is passed, drop it for the Mistral API: https://github.com/BerriAI/litellm/issues/6696
+
+        Motivation: the Mistral API doesn't support content as a list
+        """
+        new_messages = []
+        for m in messages:
+            special_keys = ["role", "content", "tool_calls", "function_call"]
+            extra_args = {}
+            if isinstance(m, dict):
+                for k, v in m.items():
+                    if k not in special_keys:
+                        extra_args[k] = v
+            texts = ""
+            _content = m.get("content")
+            if _content is not None and isinstance(_content, list):
+                for c in _content:
+                    _text: Optional[str] = c.get("text")
+                    if c["type"] == "image_url":
+                        return messages
+                    elif c["type"] == "text" and isinstance(_text, str):
+                        texts += _text
+            elif _content is not None and isinstance(_content, str):
+                texts = _content
+
+            new_m = {"role": m["role"], "content": texts, **extra_args}
+
+            if m.get("tool_calls"):
+                new_m["tool_calls"] = m.get("tool_calls")
+
+            new_m = cls._handle_name_in_message(new_m)
+
+            new_messages.append(new_m)
+        return new_messages
+
+    @classmethod
+    def _handle_name_in_message(cls, message: dict) -> dict:
+        """
+        The Mistral API only supports `name` on tool messages.
+
+        If role == "tool", we keep `name`;
+        otherwise, we drop it.
+        """
+        if message.get("name") is not None:
+            if message["role"] == "tool":
+                pass  # Mistral supports `name` on tool messages; keep it
+            else:
+                message.pop("name", None)
+
+        return message
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 80ad2ca35..29028e053 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -259,43 +259,6 @@ def mistral_instruct_pt(messages):
     return prompt
 
 
-def mistral_api_pt(messages):
-    """
-    - handles scenario where content is list and not string
-    - content list is just text, and no images
-    - if image passed in, then just return as is (user-intended)
-
-    Motivation: mistral api doesn't support content as a list
-    """
-    new_messages = []
-    for m in messages:
-        special_keys = ["role", "content", "tool_calls", "function_call"]
-        extra_args = {}
-        if isinstance(m, dict):
-            for k, v in m.items():
-                if k not in special_keys:
-                    extra_args[k] = v
-        texts = ""
-        if m.get("content", None) is not None and isinstance(m["content"], list):
-            for c in m["content"]:
-                if c["type"] == "image_url":
-                    return messages
-                elif c["type"] == "text" and isinstance(c["text"], str):
-                    texts += c["text"]
-        elif m.get("content", None) is not None and isinstance(m["content"], str):
-            texts = m["content"]
-
-        new_m = {"role": m["role"], "content": texts, **extra_args}
-
-        if new_m["role"] == "tool" and m.get("name"):
-            new_m["name"] = m["name"]
-        if m.get("tool_calls"):
-            new_m["tool_calls"] = m["tool_calls"]
-
-        new_messages.append(new_m)
-    return new_messages
-
-
 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
 def falcon_instruct_pt(messages):
     prompt = ""
@@ -2853,7 +2816,7 @@ def prompt_factory(
         else:
             return gemini_text_image_pt(messages=messages)
     elif custom_llm_provider == "mistral":
-        return mistral_api_pt(messages=messages)
+        return litellm.MistralConfig._transform_messages(messages=messages)
     elif custom_llm_provider == "bedrock":
         if "amazon.titan-text" in model:
             return amazon_titan_pt(messages=messages)
diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py
index 96004eb4e..18ac7216f 100644
--- a/tests/llm_translation/base_llm_unit_tests.py
+++ b/tests/llm_translation/base_llm_unit_tests.py
@@ -45,6 +45,14 @@ class BaseLLMChatTest(ABC):
         )
         assert response is not None
 
+    def test_message_with_name(self):
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [
+            {"role": "user", "content": "Hello", "name": "test_name"},
+        ]
+        response = litellm.completion(**base_completion_call_args, messages=messages)
+        assert response is not None
+
     @pytest.fixture
     def pdf_messages(self):
         import base64
diff --git a/tests/llm_translation/test_mistral_api.py b/tests/llm_translation/test_mistral_api.py
new file mode 100644
index 000000000..b2cb36541
--- /dev/null
+++ b/tests/llm_translation/test_mistral_api.py
@@ -0,0 +1,34 @@
+import asyncio
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+import litellm.types
+import litellm.types.utils
+from litellm.llms.anthropic.chat import ModelResponseIterator
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+from typing import Optional
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+import litellm
+
+from litellm.llms.anthropic.common_utils import process_anthropic_headers
+from httpx import Headers
+from base_llm_unit_tests import BaseLLMChatTest
+
+
+class TestMistralCompletion(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        litellm.set_verbose = True
+        return {"model": "mistral/mistral-small-latest"}