(fix) OpenAI's optional messages[].name does not work with Mistral API (#6701)

* use helper for _transform_messages mistral

* add test_message_with_name to base LLMChat test

* fix linting
Ishaan Jaff 2024-11-11 18:03:41 -08:00 committed by GitHub
parent c3bc9e6b12
commit 9d20c19e0c
4 changed files with 100 additions and 38 deletions
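
Minimal repro of the fixed behavior (illustrative, not part of the diff; assumes MISTRAL_API_KEY is set): before this commit, the OpenAI-style optional `name` field on a user message was forwarded as-is and the Mistral API rejected the request; after it, `name` is dropped before the request is sent.

    import litellm  # requires MISTRAL_API_KEY in the environment

    # Previously failed against Mistral because of the extra `name` field;
    # with this commit the field is stripped during message transformation.
    response = litellm.completion(
        model="mistral/mistral-small-latest",
        messages=[{"role": "user", "content": "Hello", "name": "test_name"}],
    )
    print(response.choices[0].message.content)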


@@ -10,6 +10,7 @@ import types
 from typing import List, Literal, Optional, Tuple, Union

 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues


 class MistralConfig:
@@ -148,3 +149,59 @@ class MistralConfig:
             or get_secret_str("MISTRAL_API_KEY")
         )
         return api_base, dynamic_api_key
+
+    @classmethod
+    def _transform_messages(cls, messages: List[AllMessageValues]):
+        """
+        - handles scenario where content is list and not string
+        - content list is just text, and no images
+        - if image passed in, then just return as is (user-intended)
+        - if `name` is passed, then drop it for mistral API: https://github.com/BerriAI/litellm/issues/6696
+
+        Motivation: mistral api doesn't support content as a list
+        """
+        new_messages = []
+        for m in messages:
+            special_keys = ["role", "content", "tool_calls", "function_call"]
+            extra_args = {}
+            if isinstance(m, dict):
+                for k, v in m.items():
+                    if k not in special_keys:
+                        extra_args[k] = v
+            texts = ""
+            _content = m.get("content")
+            if _content is not None and isinstance(_content, list):
+                for c in _content:
+                    _text: Optional[str] = c.get("text")
+                    if c["type"] == "image_url":
+                        return messages
+                    elif c["type"] == "text" and isinstance(_text, str):
+                        texts += _text
+            elif _content is not None and isinstance(_content, str):
+                texts = _content
+
+            new_m = {"role": m["role"], "content": texts, **extra_args}
+
+            if m.get("tool_calls"):
+                new_m["tool_calls"] = m.get("tool_calls")
+
+            new_m = cls._handle_name_in_message(new_m)
+
+            new_messages.append(new_m)
+        return new_messages
+
+    @classmethod
+    def _handle_name_in_message(cls, message: dict) -> dict:
+        """
+        Mistral API only supports `name` in tool messages
+        If role == tool, then we keep `name`
+        Otherwise, we drop `name`
+        """
+        if message.get("name") is not None:
+            if message["role"] == "tool":
+                message["name"] = message.get("name")
+            else:
+                message.pop("name", None)
+
+        return message
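
Illustrative use of the two class methods above (a sketch, not part of the diff; `litellm.MistralConfig` is the same path the updated `prompt_factory` below uses): `name` is stripped from non-tool messages and preserved on tool messages.

    import litellm

    messages = [
        {"role": "user", "content": "Hello", "name": "test_name"},
        {"role": "tool", "content": "42", "tool_call_id": "call_1", "name": "get_answer"},
    ]
    print(litellm.MistralConfig._transform_messages(messages))
    # [{'role': 'user', 'content': 'Hello'},
    #  {'role': 'tool', 'content': '42', 'tool_call_id': 'call_1', 'name': 'get_answer'}]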


@@ -259,43 +259,6 @@ def mistral_instruct_pt(messages):
     return prompt


-def mistral_api_pt(messages):
-    """
-    - handles scenario where content is list and not string
-    - content list is just text, and no images
-    - if image passed in, then just return as is (user-intended)
-
-    Motivation: mistral api doesn't support content as a list
-    """
-    new_messages = []
-    for m in messages:
-        special_keys = ["role", "content", "tool_calls", "function_call"]
-        extra_args = {}
-        if isinstance(m, dict):
-            for k, v in m.items():
-                if k not in special_keys:
-                    extra_args[k] = v
-        texts = ""
-        if m.get("content", None) is not None and isinstance(m["content"], list):
-            for c in m["content"]:
-                if c["type"] == "image_url":
-                    return messages
-                elif c["type"] == "text" and isinstance(c["text"], str):
-                    texts += c["text"]
-        elif m.get("content", None) is not None and isinstance(m["content"], str):
-            texts = m["content"]
-
-        new_m = {"role": m["role"], "content": texts, **extra_args}
-
-        if new_m["role"] == "tool" and m.get("name"):
-            new_m["name"] = m["name"]
-
-        if m.get("tool_calls"):
-            new_m["tool_calls"] = m["tool_calls"]
-
-        new_messages.append(new_m)
-    return new_messages
-

 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
 def falcon_instruct_pt(messages):
     prompt = ""
@@ -2853,7 +2816,7 @@ def prompt_factory(
         else:
             return gemini_text_image_pt(messages=messages)
     elif custom_llm_provider == "mistral":
-        return mistral_api_pt(messages=messages)
+        return litellm.MistralConfig._transform_messages(messages=messages)
    elif custom_llm_provider == "bedrock":
        if "amazon.titan-text" in model:
            return amazon_titan_pt(messages=messages)
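
With the routing change above, both content-handling behaviours documented in the deleted `mistral_api_pt` docstring now live in `MistralConfig._transform_messages`. A sketch of each (illustrative, not from the diff): text-only content lists are flattened into a single string, while any `image_url` part short-circuits and returns the original messages untouched.

    import litellm

    # A list of text parts is flattened into one string.
    text_only = [
        {"role": "user", "content": [{"type": "text", "text": "Hi "}, {"type": "text", "text": "there"}]}
    ]
    print(litellm.MistralConfig._transform_messages(text_only))
    # [{'role': 'user', 'content': 'Hi there'}]

    # Any image part returns the original message list unchanged.
    with_image = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ]
    assert litellm.MistralConfig._transform_messages(with_image) is with_image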


@@ -45,6 +45,14 @@ class BaseLLMChatTest(ABC):
         )
         assert response is not None

+    def test_message_with_name(self):
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [
+            {"role": "user", "content": "Hello", "name": "test_name"},
+        ]
+        response = litellm.completion(**base_completion_call_args, messages=messages)
+        assert response is not None
+
     @pytest.fixture
     def pdf_messages(self):
         import base64


@@ -0,0 +1,34 @@
+import asyncio
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+import litellm.types
+import litellm.types.utils
+from litellm.llms.anthropic.chat import ModelResponseIterator
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+from typing import Optional
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+import litellm
+from litellm.llms.anthropic.common_utils import process_anthropic_headers
+from httpx import Headers
+from base_llm_unit_tests import BaseLLMChatTest
+
+
+class TestMistralCompletion(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        litellm.set_verbose = True
+        return {"model": "mistral/mistral-small-latest"}