forked from phoenix/litellm-mirror
(fix) OpenAI's optional messages[].name does not work with Mistral API (#6701)
* use a helper for the Mistral `_transform_messages` * add `test_message_with_name` to the base LLM chat test suite * fix linting
This commit is contained in:
parent
c3bc9e6b12
commit
9d20c19e0c
4 changed files with 100 additions and 38 deletions
|
@ -10,6 +10,7 @@ import types
|
|||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class MistralConfig:
|
||||
|
@ -148,3 +149,59 @@ class MistralConfig:
|
|||
or get_secret_str("MISTRAL_API_KEY")
|
||||
)
|
||||
return api_base, dynamic_api_key
|
||||
|
||||
@classmethod
|
||||
def _transform_messages(cls, messages: List[AllMessageValues]):
|
||||
"""
|
||||
- handles scenario where content is list and not string
|
||||
- content list is just text, and no images
|
||||
- if image passed in, then just return as is (user-intended)
|
||||
- if `name` is passed, then drop it for mistral API: https://github.com/BerriAI/litellm/issues/6696
|
||||
|
||||
Motivation: mistral api doesn't support content as a list
|
||||
"""
|
||||
new_messages = []
|
||||
for m in messages:
|
||||
special_keys = ["role", "content", "tool_calls", "function_call"]
|
||||
extra_args = {}
|
||||
if isinstance(m, dict):
|
||||
for k, v in m.items():
|
||||
if k not in special_keys:
|
||||
extra_args[k] = v
|
||||
texts = ""
|
||||
_content = m.get("content")
|
||||
if _content is not None and isinstance(_content, list):
|
||||
for c in _content:
|
||||
_text: Optional[str] = c.get("text")
|
||||
if c["type"] == "image_url":
|
||||
return messages
|
||||
elif c["type"] == "text" and isinstance(_text, str):
|
||||
texts += _text
|
||||
elif _content is not None and isinstance(_content, str):
|
||||
texts = _content
|
||||
|
||||
new_m = {"role": m["role"], "content": texts, **extra_args}
|
||||
|
||||
if m.get("tool_calls"):
|
||||
new_m["tool_calls"] = m.get("tool_calls")
|
||||
|
||||
new_m = cls._handle_name_in_message(new_m)
|
||||
|
||||
new_messages.append(new_m)
|
||||
return new_messages
|
||||
|
||||
@classmethod
|
||||
def _handle_name_in_message(cls, message: dict) -> dict:
|
||||
"""
|
||||
Mistral API only supports `name` in tool messages
|
||||
|
||||
If role == tool, then we keep `name`
|
||||
Otherwise, we drop `name`
|
||||
"""
|
||||
if message.get("name") is not None:
|
||||
if message["role"] == "tool":
|
||||
message["name"] = message.get("name")
|
||||
else:
|
||||
message.pop("name", None)
|
||||
|
||||
return message
|
||||
|
|
|
@ -259,43 +259,6 @@ def mistral_instruct_pt(messages):
|
|||
return prompt
|
||||
|
||||
|
||||
def mistral_api_pt(messages):
    """
    Transform OpenAI-style messages into the shape the Mistral API accepts.

    - If `content` is a list of text parts, concatenate them into a single
      string (the Mistral API does not accept content as a list).
    - If any image part is present, return the messages unchanged
      (user-intended multimodal payload).
    - Keep the optional `name` field only on tool messages; the Mistral API
      rejects `name` on other roles
      (https://github.com/BerriAI/litellm/issues/6696).
    """
    new_messages = []
    # Keys handled explicitly below; anything else is forwarded untouched.
    special_keys = ["role", "content", "tool_calls", "function_call"]
    for m in messages:
        extra_args = {}
        if isinstance(m, dict):
            for k, v in m.items():
                if k not in special_keys:
                    extra_args[k] = v

        texts = ""
        _content = m.get("content")
        if _content is not None and isinstance(_content, list):
            for c in _content:
                # .get avoids a KeyError on parts missing "type"/"text" keys
                if c.get("type") == "image_url":
                    # Image present: hand back the original payload as-is.
                    return messages
                elif c.get("type") == "text" and isinstance(c.get("text"), str):
                    texts += c["text"]
        elif _content is not None and isinstance(_content, str):
            texts = _content

        new_m = {"role": m["role"], "content": texts, **extra_args}

        # Fix: `name` used to leak through extra_args onto every role;
        # Mistral only supports it on tool messages, so drop it elsewhere.
        if new_m.get("name") is not None and new_m.get("role") != "tool":
            new_m.pop("name", None)
        if m.get("tool_calls"):
            new_m["tool_calls"] = m["tool_calls"]

        new_messages.append(new_m)
    return new_messages
|
||||
|
||||
|
||||
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
|
||||
def falcon_instruct_pt(messages):
|
||||
prompt = ""
|
||||
|
@ -2853,7 +2816,7 @@ def prompt_factory(
|
|||
else:
|
||||
return gemini_text_image_pt(messages=messages)
|
||||
elif custom_llm_provider == "mistral":
|
||||
return mistral_api_pt(messages=messages)
|
||||
return litellm.MistralConfig._transform_messages(messages=messages)
|
||||
elif custom_llm_provider == "bedrock":
|
||||
if "amazon.titan-text" in model:
|
||||
return amazon_titan_pt(messages=messages)
|
||||
|
|
|
@ -45,6 +45,14 @@ class BaseLLMChatTest(ABC):
|
|||
)
|
||||
assert response is not None
|
||||
|
||||
def test_message_with_name(self):
    """A message carrying OpenAI's optional `name` field completes without error."""
    call_args = self.get_base_completion_call_args()
    response = litellm.completion(
        **call_args,
        messages=[{"role": "user", "content": "Hello", "name": "test_name"}],
    )
    assert response is not None
|
||||
|
||||
@pytest.fixture
|
||||
def pdf_messages(self):
|
||||
import base64
|
||||
|
|
34
tests/llm_translation/test_mistral_api.py
Normal file
34
tests/llm_translation/test_mistral_api.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import litellm.types
|
||||
import litellm.types.utils
|
||||
from litellm.llms.anthropic.chat import ModelResponseIterator
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
|
||||
from litellm.llms.anthropic.common_utils import process_anthropic_headers
|
||||
from httpx import Headers
|
||||
from base_llm_unit_tests import BaseLLMChatTest
|
||||
|
||||
|
||||
class TestMistralCompletion(BaseLLMChatTest):
    """Runs the shared chat-completion test suite against the Mistral provider."""

    def get_base_completion_call_args(self) -> dict:
        # Verbose logging so failures surface the raw request/response.
        litellm.set_verbose = True
        return {"model": "mistral/mistral-small-latest"}
|
Loading…
Add table
Add a link
Reference in a new issue