diff --git a/litellm/__init__.py b/litellm/__init__.py
index 6dddc8a22..3ac5d0581 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -787,6 +787,7 @@ from .llms.openai import (
     OpenAIConfig,
     OpenAITextCompletionConfig,
     MistralConfig,
+    MistralEmbeddingConfig,
     DeepInfraConfig,
 )
 from .llms.azure import (
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index dec86d35d..fa7745af8 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -164,6 +164,49 @@ class MistralConfig:
         return optional_params
 
 
+class MistralEmbeddingConfig:
+    """
+    Reference: https://docs.mistral.ai/api/#operation/createEmbedding
+    """
+
+    def __init__(
+        self,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "encoding_format",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "encoding_format":
+                optional_params["encoding_format"] = value
+        return optional_params
+
+
 class DeepInfraConfig:
     """
     Reference: https://deepinfra.com/docs/advanced/openai_api
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index 089d469af..d8dc798f0 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -79,10 +79,6 @@ async def add_litellm_data_to_request(
             data["cache"][k] = v
 
     verbose_proxy_logger.debug("receiving data: %s", data)
-    # users can pass in 'user' param to /chat/completions. Don't override it
-    if data.get("user", None) is None and user_api_key_dict.user_id is not None:
-        # if users are using user_api_key_auth, set `user` in `data`
-        data["user"] = user_api_key_dict.user_id
 
     if "metadata" not in data:
         data["metadata"] = {}
diff --git a/litellm/utils.py b/litellm/utils.py
index 85f5a405a..036e74767 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4943,7 +4943,18 @@ def get_optional_params_embeddings(
                 message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
             )
         return {**non_default_params, **kwargs}
-
+    if custom_llm_provider == "mistral":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="mistral",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.MistralEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params={}
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
     if (
         custom_llm_provider != "openai"
         and custom_llm_provider != "azure"
@@ -6352,7 +6363,10 @@ def get_supported_openai_params(
             "max_retries",
         ]
     elif custom_llm_provider == "mistral":
-        return litellm.MistralConfig().get_supported_openai_params()
+        if request_type == "chat_completion":
+            return litellm.MistralConfig().get_supported_openai_params()
+        elif request_type == "embeddings":
+            return litellm.MistralEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "replicate":
         return [
             "stream",
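
A minimal sketch of how the new config is exercised, using only names introduced in this diff. The commented end-to-end call is an assumption: it presumes the usual "mistral/mistral-embed" model string and a MISTRAL_API_KEY in the environment, neither of which is part of this change.

import litellm
from litellm.utils import get_supported_openai_params

# With request_type="embeddings", the mistral branch now routes to
# MistralEmbeddingConfig instead of MistralConfig.
assert get_supported_openai_params(
    model="mistral-embed",
    custom_llm_provider="mistral",
    request_type="embeddings",
) == ["encoding_format"]

# map_openai_params copies only the supported keys into optional_params.
optional_params = litellm.MistralEmbeddingConfig().map_openai_params(
    non_default_params={"encoding_format": "float"}, optional_params={}
)
assert optional_params == {"encoding_format": "float"}

# End-to-end sketch (assumption: model string and API key are not defined by this diff):
# response = litellm.embedding(
#     model="mistral/mistral-embed",
#     input=["hello world"],
#     encoding_format="float",
# )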