(fix) OpenAI's optional messages[].name does not work with Mistral API (#6701)
* use helper for _transform_messages mistral
* add test_message_with_name to base LLMChat test
* fix linting
parent c3bc9e6b12
commit 9d20c19e0c

4 changed files with 100 additions and 38 deletions
@@ -10,6 +10,7 @@ import types
 from typing import List, Literal, Optional, Tuple, Union

 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues


 class MistralConfig:
@@ -148,3 +149,59 @@ class MistralConfig:
             or get_secret_str("MISTRAL_API_KEY")
         )
         return api_base, dynamic_api_key
+
+    @classmethod
+    def _transform_messages(cls, messages: List[AllMessageValues]):
+        """
+        - handles scenario where content is list and not string
+        - content list is just text, and no images
+        - if image passed in, then just return as is (user-intended)
+        - if `name` is passed, then drop it for mistral API: https://github.com/BerriAI/litellm/issues/6696
+
+        Motivation: mistral api doesn't support content as a list
+        """
+        new_messages = []
+        for m in messages:
+            special_keys = ["role", "content", "tool_calls", "function_call"]
+            extra_args = {}
+            if isinstance(m, dict):
+                for k, v in m.items():
+                    if k not in special_keys:
+                        extra_args[k] = v
+            texts = ""
+            _content = m.get("content")
+            if _content is not None and isinstance(_content, list):
+                for c in _content:
+                    _text: Optional[str] = c.get("text")
+                    if c["type"] == "image_url":
+                        return messages
+                    elif c["type"] == "text" and isinstance(_text, str):
+                        texts += _text
+            elif _content is not None and isinstance(_content, str):
+                texts = _content
+
+            new_m = {"role": m["role"], "content": texts, **extra_args}
+
+            if m.get("tool_calls"):
+                new_m["tool_calls"] = m.get("tool_calls")
+
+            new_m = cls._handle_name_in_message(new_m)
+
+            new_messages.append(new_m)
+        return new_messages
+
+    @classmethod
+    def _handle_name_in_message(cls, message: dict) -> dict:
+        """
+        Mistral API only supports `name` in tool messages
+
+        If role == tool, then we keep `name`
+        Otherwise, we drop `name`
+        """
+        if message.get("name") is not None:
+            if message["role"] == "tool":
+                message["name"] = message.get("name")
+            else:
+                message.pop("name", None)
+
+        return message
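For illustration, a minimal sketch (not part of the commit) of what the new helpers do with OpenAI-style `name` fields; the message contents below are invented for the example, and it assumes a litellm build containing this commit:

    # Sketch only: exercises the helper added in the hunk above.
    import litellm

    messages = [
        # OpenAI allows an optional `name` on any message; Mistral's API rejects
        # it outside tool messages (see issue #6696 cited in the docstring).
        {"role": "user", "content": "Hello", "name": "test_name"},
        {"role": "tool", "content": "42", "name": "get_answer", "tool_call_id": "call_1"},
    ]

    transformed = litellm.MistralConfig._transform_messages(messages=messages)
    # Expected: the user message loses "name"; the tool message keeps it
    # (plus its tool_call_id, which passes through via extra_args).
    print(transformed)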
@@ -259,43 +259,6 @@ def mistral_instruct_pt(messages):
     return prompt


-def mistral_api_pt(messages):
-    """
-    - handles scenario where content is list and not string
-    - content list is just text, and no images
-    - if image passed in, then just return as is (user-intended)
-
-    Motivation: mistral api doesn't support content as a list
-    """
-    new_messages = []
-    for m in messages:
-        special_keys = ["role", "content", "tool_calls", "function_call"]
-        extra_args = {}
-        if isinstance(m, dict):
-            for k, v in m.items():
-                if k not in special_keys:
-                    extra_args[k] = v
-        texts = ""
-        if m.get("content", None) is not None and isinstance(m["content"], list):
-            for c in m["content"]:
-                if c["type"] == "image_url":
-                    return messages
-                elif c["type"] == "text" and isinstance(c["text"], str):
-                    texts += c["text"]
-        elif m.get("content", None) is not None and isinstance(m["content"], str):
-            texts = m["content"]
-
-        new_m = {"role": m["role"], "content": texts, **extra_args}
-
-        if new_m["role"] == "tool" and m.get("name"):
-            new_m["name"] = m["name"]
-        if m.get("tool_calls"):
-            new_m["tool_calls"] = m["tool_calls"]
-
-        new_messages.append(new_m)
-    return new_messages
-
-
 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
 def falcon_instruct_pt(messages):
     prompt = ""
@@ -2853,7 +2816,7 @@ def prompt_factory(
         else:
             return gemini_text_image_pt(messages=messages)
     elif custom_llm_provider == "mistral":
-        return mistral_api_pt(messages=messages)
+        return litellm.MistralConfig._transform_messages(messages=messages)
     elif custom_llm_provider == "bedrock":
         if "amazon.titan-text" in model:
            return amazon_titan_pt(messages=messages)
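As a usage note, the effect of this rerouting can be sketched directly (again assuming a litellm build with this commit): text-only content lists are flattened to a single string, per the helper's docstring, while any image part short-circuits and returns the messages unchanged.

    # Sketch only: list-of-text content is flattened; an image_url part would
    # cause the original messages to be returned as-is.
    import litellm

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe "},
                {"type": "text", "text": "the weather."},
            ],
        }
    ]

    print(litellm.MistralConfig._transform_messages(messages=messages))
    # Expected: [{'role': 'user', 'content': 'Describe the weather.'}]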
@@ -45,6 +45,14 @@ class BaseLLMChatTest(ABC):
         )
         assert response is not None

+    def test_message_with_name(self):
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [
+            {"role": "user", "content": "Hello", "name": "test_name"},
+        ]
+        response = litellm.completion(**base_completion_call_args, messages=messages)
+        assert response is not None
+
     @pytest.fixture
     def pdf_messages(self):
         import base64
tests/llm_translation/test_mistral_api.py (new file, 34 lines)

@@ -0,0 +1,34 @@
+import asyncio
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+import litellm.types
+import litellm.types.utils
+from litellm.llms.anthropic.chat import ModelResponseIterator
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+from typing import Optional
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+import litellm
+
+from litellm.llms.anthropic.common_utils import process_anthropic_headers
+from httpx import Headers
+from base_llm_unit_tests import BaseLLMChatTest
+
+
+class TestMistralCompletion(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        litellm.set_verbose = True
+        return {"model": "mistral/mistral-small-latest"}
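One possible way to run just the new test locally (this assumes the usual litellm test layout, with tests/llm_translation as the working directory and MISTRAL_API_KEY set in the environment):

    # Sketch only: invoke pytest programmatically for the new suite.
    import pytest

    pytest.main(["-v", "-k", "test_message_with_name", "test_mistral_api.py"])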