Litellm merge pr (#7161)

* build: merge branch

* test: fix openai naming

* fix(main.py): fix openai renaming

* style: ignore function length for config factory

* fix(sagemaker/): fix routing logic

* fix: fix imports

* fix: fix override
This commit is contained in:
Krish Dholakia 2024-12-10 22:49:26 -08:00 committed by GitHub
parent d5aae81c6d
commit 350cfc36f7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
88 changed files with 3617 additions and 4421 deletions

View file

@ -9,11 +9,16 @@ Docs - https://docs.mistral.ai/api/
import types
from typing import List, Literal, Optional, Tuple, Union
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.llms.prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
strip_none_values_from_message,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
class MistralConfig:
class MistralConfig(OpenAIGPTConfig):
"""
Reference: https://docs.mistral.ai/api/
@ -67,23 +72,9 @@ class MistralConfig:
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
return super().get_config()
def get_supported_openai_params(self):
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"temperature",
@ -104,7 +95,13 @@ class MistralConfig:
else: # openai 'tool_choice' object param not supported by Mistral API
return "any"
def map_openai_params(self, non_default_params: dict, optional_params: dict):
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
@ -150,8 +147,9 @@ class MistralConfig:
)
return api_base, dynamic_api_key
@classmethod
def _transform_messages(cls, messages: List[AllMessageValues]):
def _transform_messages(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"""
- handles scenario where content is list and not string
- content list is just text, and no images
@ -160,48 +158,36 @@ class MistralConfig:
Motivation: mistral api doesn't support content as a list
"""
new_messages = []
## 1. If 'image_url' in content, then return as is
for m in messages:
special_keys = ["role", "content", "tool_calls", "function_call"]
extra_args = {}
if isinstance(m, dict):
for k, v in m.items():
if k not in special_keys:
extra_args[k] = v
texts = ""
_content = m.get("content")
if _content is not None and isinstance(_content, list):
for c in _content:
_text: Optional[str] = c.get("text")
if c["type"] == "image_url":
_content_block = m.get("content")
if _content_block and isinstance(_content_block, list):
for c in _content_block:
if c.get("type") == "image_url":
return messages
elif c["type"] == "text" and isinstance(_text, str):
texts += _text
elif _content is not None and isinstance(_content, str):
texts = _content
new_m = {"role": m["role"], "content": texts, **extra_args}
## 2. If content is list, then convert to string
messages = handle_messages_with_content_list_to_str_conversion(messages)
if m.get("tool_calls"):
new_m["tool_calls"] = m.get("tool_calls")
## 3. Handle name in message
new_messages: List[AllMessageValues] = []
for m in messages:
m = MistralConfig._handle_name_in_message(m)
m = strip_none_values_from_message(m) # prevents 'extra_forbidden' error
new_messages.append(m)
new_m = cls._handle_name_in_message(new_m)
new_messages.append(new_m)
return new_messages
@classmethod
def _handle_name_in_message(cls, message: dict) -> dict:
def _handle_name_in_message(cls, message: AllMessageValues) -> AllMessageValues:
"""
Mistral API only supports `name` in tool messages
If role == tool, then we keep `name`
Otherwise, we drop `name`
"""
if message.get("name") is not None:
if message["role"] == "tool":
message["name"] = message.get("name")
else:
message.pop("name", None)
_name = message.get("name") # type: ignore
if _name is not None and message["role"] != "tool":
message.pop("name", None) # type: ignore
return message