diff --git a/.circleci/config.yml b/.circleci/config.yml
index 5f4628d26..736bb8e8a 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -202,6 +202,7 @@ jobs:
           -e REDIS_PORT=$REDIS_PORT \
           -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
           -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
+          -e MISTRAL_API_KEY=$MISTRAL_API_KEY \
           -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
           -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
           -e AWS_REGION_NAME=$AWS_REGION_NAME \
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 6dddc8a22..3ac5d0581 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -787,6 +787,7 @@ from .llms.openai import (
     OpenAIConfig,
     OpenAITextCompletionConfig,
     MistralConfig,
+    MistralEmbeddingConfig,
     DeepInfraConfig,
 )
 from .llms.azure import (
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index dec86d35d..fa7745af8 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -164,6 +164,49 @@ class MistralConfig:
         return optional_params
 
 
+class MistralEmbeddingConfig:
+    """
+    Reference: https://docs.mistral.ai/api/#operation/createEmbedding
+    """
+
+    def __init__(
+        self,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "encoding_format",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "encoding_format":
+                optional_params["encoding_format"] = value
+        return optional_params
+
+
 class DeepInfraConfig:
     """
     Reference: https://deepinfra.com/docs/advanced/openai_api
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index 089d469af..d8dc798f0 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -79,10 +79,6 @@ async def add_litellm_data_to_request(
             data["cache"][k] = v
 
     verbose_proxy_logger.debug("receiving data: %s", data)
-    # users can pass in 'user' param to /chat/completions. Don't override it
-    if data.get("user", None) is None and user_api_key_dict.user_id is not None:
-        # if users are using user_api_key_auth, set `user` in `data`
-        data["user"] = user_api_key_dict.user_id
 
     if "metadata" not in data:
         data["metadata"] = {}
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 3c6b2f201..a21378f31 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -14,10 +14,9 @@ model_list:
     litellm_params:
       model: openai/*
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: my-triton-model
+  - model_name: mistral-embed
     litellm_params:
-      model: triton/any"
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
+      model: mistral/mistral-embed
 
 general_settings:
   master_key: sk-1234
diff --git a/litellm/utils.py b/litellm/utils.py
index 85f5a405a..036e74767 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4943,7 +4943,18 @@ def get_optional_params_embeddings(
                 message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
             )
         return {**non_default_params, **kwargs}
-
+    if custom_llm_provider == "mistral":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="mistral",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.MistralEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params={}
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
     if (
         custom_llm_provider != "openai"
         and custom_llm_provider != "azure"
@@ -6352,7 +6363,10 @@ def get_supported_openai_params(
             "max_retries",
         ]
     elif custom_llm_provider == "mistral":
-        return litellm.MistralConfig().get_supported_openai_params()
+        if request_type == "chat_completion":
+            return litellm.MistralConfig().get_supported_openai_params()
+        elif request_type == "embeddings":
+            return litellm.MistralEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "replicate":
         return [
             "stream",
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index f9f77c05a..f1853dc83 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -85,6 +85,9 @@ model_list:
     litellm_params:
       model: openai/*
       api_key: os.environ/OPENAI_API_KEY
+  - model_name: mistral-embed
+    litellm_params:
+      model: mistral/mistral-embed
   - model_name: gpt-instruct # [PROD TEST] - tests if `/health` automatically infers this to be a text completion model
     litellm_params:
       model: text-completion-openai/gpt-3.5-turbo-instruct
diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py
index 83d387ffb..e2f600b76 100644
--- a/tests/test_openai_endpoints.py
+++ b/tests/test_openai_endpoints.py
@@ -22,6 +22,7 @@ async def generate_key(
         "text-embedding-ada-002",
         "dall-e-2",
         "fake-openai-endpoint-2",
+        "mistral-embed",
     ],
 ):
     url = "http://0.0.0.0:4000/key/generate"
@@ -197,14 +198,14 @@ async def completion(session, key):
     return response
 
 
-async def embeddings(session, key):
+async def embeddings(session, key, model="text-embedding-ada-002"):
     url = "http://0.0.0.0:4000/embeddings"
     headers = {
         "Authorization": f"Bearer {key}",
         "Content-Type": "application/json",
     }
     data = {
-        "model": "text-embedding-ada-002",
+        "model": model,
         "input": ["hello world"],
     }
 
@@ -408,6 +409,9 @@ async def test_embeddings():
         key_2 = key_gen["key"]
         await embeddings(session=session, key=key_2)
 
+        # embedding request with non OpenAI model
+        await embeddings(session=session, key=key, model="mistral-embed")
+
 
 @pytest.mark.asyncio
 async def test_image_generation():
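
Reviewer note: below is a minimal sketch (not part of the patch) of how the new embedding path can be exercised once this lands. The `mistral/mistral-embed` model string and the `encoding_format` parameter are taken from the diff above; the sample input and the final print are illustrative only.

# Sketch: direct SDK call through the new Mistral embedding route.
# Assumes MISTRAL_API_KEY is exported -- the same env var threaded
# through .circleci/config.yml in this PR.
import litellm

response = litellm.embedding(
    model="mistral/mistral-embed",  # provider/model string, as in proxy_config.yaml
    input=["hello world"],          # same payload the updated test sends
    encoding_format="float",        # the one OpenAI-style param MistralEmbeddingConfig maps
)
print(len(response.data[0]["embedding"]))  # embedding dimensionality

Against a running proxy, the equivalent request is the `/embeddings` call the updated test now makes with `model="mistral-embed"`. Any other OpenAI embedding param (e.g. `user`) is rejected for this provider by `_check_valid_arg` unless `litellm.drop_params = True` is set, per the error message in `get_optional_params_embeddings` above.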