diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md
index 582636630a..868c8602c1 100644
--- a/docs/my-website/docs/providers/vertex.md
+++ b/docs/my-website/docs/providers/vertex.md
@@ -1531,28 +1531,103 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
-### Advanced Use `task_type` and `title` (Vertex Specific Params)
+### Supported OpenAI (Unified) Params
-👉 `task_type` and `title` are vertex specific params
+| [param](../embedding/supported_embedding.md#input-params-for-litellmembedding) | type | [vertex equivalent](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api) |
+|-------|-------------|--------------------|
+| `input` | **string or List[string]** | `instances` |
+| `dimensions` | **int** | `output_dimensionality` |
+| `input_type` | **Literal["RETRIEVAL_QUERY","RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "QUESTION_ANSWERING", "FACT_VERIFICATION"]** | `task_type` |
-LiteLLM Supported Vertex Specific Params
+#### Usage with OpenAI (Unified) Params
+
+
+
+
```python
-auto_truncate: Optional[bool] = None
-task_type: Optional[Literal["RETRIEVAL_QUERY","RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "QUESTION_ANSWERING", "FACT_VERIFICATION"]] = None
-title: Optional[str] = None # The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
+response = litellm.embedding(
+ model="vertex_ai/text-embedding-004",
+ input=["good morning from litellm", "gm"]
+ input_type = "RETRIEVAL_DOCUMENT",
+ dimensions=1,
+)
```
+
+
+
+
+```python
+import openai
+
+client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
+
+response = client.embeddings.create(
+ model="text-embedding-004",
+ input = ["good morning from litellm", "gm"],
+ dimensions=1,
+ extra_body = {
+ "input_type": "RETRIEVAL_QUERY",
+ }
+)
+
+print(response)
+```
+
+
+
+
+### Supported Vertex Specific Params
+
+| param | type |
+|-------|-------------|
+| `auto_truncate` | **bool** |
+| `task_type` | **Literal["RETRIEVAL_QUERY","RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "QUESTION_ANSWERING", "FACT_VERIFICATION"]** |
+| `title` | **str** |
+
+#### Usage with Vertex Specific Params (Use `task_type` and `title`)
+
+You can pass any vertex specific params to the embedding model. Just pass them to the embedding function like this:
+
+[Relevant Vertex AI doc with all embedding params](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#request_body)
+
+
+
-**Example Usage with LiteLLM**
```python
response = litellm.embedding(
model="vertex_ai/text-embedding-004",
input=["good morning from litellm", "gm"]
task_type = "RETRIEVAL_DOCUMENT",
+ title = "test",
dimensions=1,
auto_truncate=True,
)
```
+
+
+
+
+```python
+import openai
+
+client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
+
+response = client.embeddings.create(
+ model="text-embedding-004",
+ input = ["good morning from litellm", "gm"],
+ dimensions=1,
+ extra_body = {
+ "task_type": "RETRIEVAL_QUERY",
+ "auto_truncate": True,
+ "title": "test",
+ }
+)
+
+print(response)
+```
+
+
## **Multi-Modal Embeddings**
diff --git a/litellm/__init__.py b/litellm/__init__.py
index a4bca6a198..1c3b8434f1 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -861,7 +861,7 @@ from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gem
GoogleAIStudioGeminiConfig,
VertexAIConfig,
)
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
VertexAITextEmbeddingConfig,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py
similarity index 100%
rename from litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_handler.py
rename to litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_transformation.py
similarity index 100%
rename from litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_transformation.py
rename to litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_transformation.py
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py
index e1f429d1d8..2a250864a6 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py
@@ -1096,269 +1096,3 @@ async def async_streaming(
)
return streamwrapper
-
-
-class VertexAITextEmbeddingConfig(BaseModel):
- """
- Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
-
- Args:
- auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
- task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
- title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
- """
-
- auto_truncate: Optional[bool] = None
- task_type: Optional[
- Literal[
- "RETRIEVAL_QUERY",
- "RETRIEVAL_DOCUMENT",
- "SEMANTIC_SIMILARITY",
- "CLASSIFICATION",
- "CLUSTERING",
- "QUESTION_ANSWERING",
- "FACT_VERIFICATION",
- ]
- ] = None
- title: Optional[str] = None
-
- def __init__(
- self,
- auto_truncate: Optional[bool] = None,
- task_type: Optional[
- Literal[
- "RETRIEVAL_QUERY",
- "RETRIEVAL_DOCUMENT",
- "SEMANTIC_SIMILARITY",
- "CLASSIFICATION",
- "CLUSTERING",
- "QUESTION_ANSWERING",
- "FACT_VERIFICATION",
- ]
- ] = None,
- title: Optional[str] = None,
- ) -> None:
- locals_ = locals()
- for key, value in locals_.items():
- if key != "self" and value is not None:
- setattr(self.__class__, key, value)
-
- @classmethod
- def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
-
- def get_supported_openai_params(self):
- return [
- "dimensions",
- ]
-
- def map_openai_params(self, non_default_params: dict, optional_params: dict):
- for param, value in non_default_params.items():
- if param == "dimensions":
- optional_params["output_dimensionality"] = value
- return optional_params
-
- def get_mapped_special_auth_params(self) -> dict:
- """
- Common auth params across bedrock/vertex_ai/azure/watsonx
- """
- return {"project": "vertex_project", "region_name": "vertex_location"}
-
- def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
- mapped_params = self.get_mapped_special_auth_params()
-
- for param, value in non_default_params.items():
- if param in mapped_params:
- optional_params[mapped_params[param]] = value
- return optional_params
-
-
-def embedding(
- model: str,
- input: Union[list, str],
- print_verbose,
- model_response: litellm.EmbeddingResponse,
- optional_params: dict,
- api_key: Optional[str] = None,
- logging_obj=None,
- encoding=None,
- vertex_project=None,
- vertex_location=None,
- vertex_credentials=None,
- aembedding=False,
-):
- # logic for parsing in - calling - parsing out model embedding calls
- try:
- import vertexai
- except:
- raise VertexAIError(
- status_code=400,
- message="vertexai import failed please run `pip install google-cloud-aiplatform`",
- )
-
- import google.auth # type: ignore
- from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
-
- ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
- try:
- print_verbose(
- f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
- )
- if vertex_credentials is not None and isinstance(vertex_credentials, str):
- import google.oauth2.service_account
-
- json_obj = json.loads(vertex_credentials)
-
- creds = google.oauth2.service_account.Credentials.from_service_account_info(
- json_obj,
- scopes=["https://www.googleapis.com/auth/cloud-platform"],
- )
- else:
- creds, _ = google.auth.default(quota_project_id=vertex_project)
- print_verbose(
- f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
- )
- vertexai.init(
- project=vertex_project, location=vertex_location, credentials=creds
- )
- except Exception as e:
- raise VertexAIError(status_code=401, message=str(e))
-
- if isinstance(input, str):
- input = [input]
-
- if optional_params is not None and isinstance(optional_params, dict):
- if optional_params.get("task_type") or optional_params.get("title"):
- # if user passed task_type or title, cast to TextEmbeddingInput
- _task_type = optional_params.pop("task_type", None)
- _title = optional_params.pop("title", None)
- input = [
- TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
- for x in input
- ]
-
- try:
- llm_model = TextEmbeddingModel.from_pretrained(model)
- except Exception as e:
- raise VertexAIError(status_code=422, message=str(e))
-
- if aembedding == True:
- return async_embedding(
- model=model,
- client=llm_model,
- input=input,
- logging_obj=logging_obj,
- model_response=model_response,
- optional_params=optional_params,
- encoding=encoding,
- )
-
- _input_dict = {"texts": input, **optional_params}
- request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
- ## LOGGING PRE-CALL
- logging_obj.pre_call(
- input=input,
- api_key=None,
- additional_args={
- "complete_input_dict": optional_params,
- "request_str": request_str,
- },
- )
-
- try:
- embeddings = llm_model.get_embeddings(**_input_dict)
- except Exception as e:
- raise VertexAIError(status_code=500, message=str(e))
-
- ## LOGGING POST-CALL
- logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
- ## Populate OpenAI compliant dictionary
- embedding_response = []
- input_tokens: int = 0
- for idx, embedding in enumerate(embeddings):
- embedding_response.append(
- {
- "object": "embedding",
- "index": idx,
- "embedding": embedding.values,
- }
- )
- input_tokens += embedding.statistics.token_count # type: ignore
- model_response.object = "list"
- model_response.data = embedding_response
- model_response.model = model
-
- usage = Usage(
- prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
- )
- setattr(model_response, "usage", usage)
-
- return model_response
-
-
-async def async_embedding(
- model: str,
- input: Union[list, str],
- model_response: litellm.EmbeddingResponse,
- logging_obj=None,
- optional_params=None,
- encoding=None,
- client=None,
-):
- """
- Async embedding implementation
- """
- _input_dict = {"texts": input, **optional_params}
- request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
- ## LOGGING PRE-CALL
- logging_obj.pre_call(
- input=input,
- api_key=None,
- additional_args={
- "complete_input_dict": optional_params,
- "request_str": request_str,
- },
- )
-
- try:
- embeddings = await client.get_embeddings_async(**_input_dict)
- except Exception as e:
- raise VertexAIError(status_code=500, message=str(e))
-
- ## LOGGING POST-CALL
- logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
- ## Populate OpenAI compliant dictionary
- embedding_response = []
- input_tokens: int = 0
- for idx, embedding in enumerate(embeddings):
- embedding_response.append(
- {
- "object": "embedding",
- "index": idx,
- "embedding": embedding.values,
- }
- )
- input_tokens += embedding.statistics.token_count
-
- model_response.object = "list"
- model_response.data = embedding_response
- model_response.model = model
- usage = Usage(
- prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
- )
- setattr(model_response, "usage", usage)
- return model_response
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py
new file mode 100644
index 0000000000..4cd5513c4f
--- /dev/null
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py
@@ -0,0 +1,283 @@
+import json
+import os
+import types
+from typing import Literal, Optional, Union
+
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
+ VertexAIError,
+)
+from litellm.types.llms.vertex_ai import *
+from litellm.utils import Usage
+
+
+class VertexAITextEmbeddingConfig(BaseModel):
+ """
+ Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
+
+ Args:
+ auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
+ task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
+ title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
+ """
+
+ auto_truncate: Optional[bool] = None
+ task_type: Optional[
+ Literal[
+ "RETRIEVAL_QUERY",
+ "RETRIEVAL_DOCUMENT",
+ "SEMANTIC_SIMILARITY",
+ "CLASSIFICATION",
+ "CLUSTERING",
+ "QUESTION_ANSWERING",
+ "FACT_VERIFICATION",
+ ]
+ ] = None
+ title: Optional[str] = None
+
+ def __init__(
+ self,
+ auto_truncate: Optional[bool] = None,
+ task_type: Optional[
+ Literal[
+ "RETRIEVAL_QUERY",
+ "RETRIEVAL_DOCUMENT",
+ "SEMANTIC_SIMILARITY",
+ "CLASSIFICATION",
+ "CLUSTERING",
+ "QUESTION_ANSWERING",
+ "FACT_VERIFICATION",
+ ]
+ ] = None,
+ title: Optional[str] = None,
+ ) -> None:
+ locals_ = locals()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(self):
+ return ["dimensions"]
+
+ def map_openai_params(
+ self, non_default_params: dict, optional_params: dict, kwargs: dict
+ ):
+ for param, value in non_default_params.items():
+ if param == "dimensions":
+ optional_params["output_dimensionality"] = value
+
+ if "input_type" in kwargs:
+ optional_params["task_type"] = kwargs.pop("input_type")
+ return optional_params, kwargs
+
+ def get_mapped_special_auth_params(self) -> dict:
+ """
+ Common auth params across bedrock/vertex_ai/azure/watsonx
+ """
+ return {"project": "vertex_project", "region_name": "vertex_location"}
+
+ def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+ mapped_params = self.get_mapped_special_auth_params()
+
+ for param, value in non_default_params.items():
+ if param in mapped_params:
+ optional_params[mapped_params[param]] = value
+ return optional_params
+
+
+def embedding(
+ model: str,
+ input: Union[list, str],
+ print_verbose,
+ model_response: litellm.EmbeddingResponse,
+ optional_params: dict,
+ api_key: Optional[str] = None,
+ logging_obj=None,
+ encoding=None,
+ vertex_project=None,
+ vertex_location=None,
+ vertex_credentials=None,
+ aembedding=False,
+):
+ # logic for parsing in - calling - parsing out model embedding calls
+ try:
+ import vertexai
+ except:
+ raise VertexAIError(
+ status_code=400,
+ message="vertexai import failed please run `pip install google-cloud-aiplatform`",
+ )
+
+ import google.auth # type: ignore
+ from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
+
+ ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
+ try:
+ print_verbose(
+ f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
+ )
+ if vertex_credentials is not None and isinstance(vertex_credentials, str):
+ import google.oauth2.service_account
+
+ json_obj = json.loads(vertex_credentials)
+
+ creds = google.oauth2.service_account.Credentials.from_service_account_info(
+ json_obj,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ else:
+ creds, _ = google.auth.default(quota_project_id=vertex_project)
+ print_verbose(
+ f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
+ )
+ vertexai.init(
+ project=vertex_project, location=vertex_location, credentials=creds
+ )
+ except Exception as e:
+ raise VertexAIError(status_code=401, message=str(e))
+
+ if isinstance(input, str):
+ input = [input]
+
+ if optional_params is not None and isinstance(optional_params, dict):
+ if optional_params.get("task_type") or optional_params.get("title"):
+ # if user passed task_type or title, cast to TextEmbeddingInput
+ _task_type = optional_params.pop("task_type", None)
+ _title = optional_params.pop("title", None)
+ input = [
+ TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
+ for x in input
+ ]
+
+ try:
+ llm_model = TextEmbeddingModel.from_pretrained(model)
+ except Exception as e:
+ raise VertexAIError(status_code=422, message=str(e))
+
+ if aembedding == True:
+ return async_embedding(
+ model=model,
+ client=llm_model,
+ input=input,
+ logging_obj=logging_obj,
+ model_response=model_response,
+ optional_params=optional_params,
+ encoding=encoding,
+ )
+
+ _input_dict = {"texts": input, **optional_params}
+ request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
+ ## LOGGING PRE-CALL
+ logging_obj.pre_call(
+ input=input,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ try:
+ embeddings = llm_model.get_embeddings(**_input_dict)
+ except Exception as e:
+ raise VertexAIError(status_code=500, message=str(e))
+
+ ## LOGGING POST-CALL
+ logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
+ ## Populate OpenAI compliant dictionary
+ embedding_response = []
+ input_tokens: int = 0
+ for idx, embedding in enumerate(embeddings):
+ embedding_response.append(
+ {
+ "object": "embedding",
+ "index": idx,
+ "embedding": embedding.values,
+ }
+ )
+ input_tokens += embedding.statistics.token_count # type: ignore
+ model_response.object = "list"
+ model_response.data = embedding_response
+ model_response.model = model
+
+ usage = Usage(
+ prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+ )
+ setattr(model_response, "usage", usage)
+
+ return model_response
+
+
+async def async_embedding(
+ model: str,
+ input: Union[list, str],
+ model_response: litellm.EmbeddingResponse,
+ logging_obj=None,
+ optional_params=None,
+ encoding=None,
+ client=None,
+):
+ """
+ Async embedding implementation
+ """
+ _input_dict = {"texts": input, **optional_params}
+ request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
+ ## LOGGING PRE-CALL
+ logging_obj.pre_call(
+ input=input,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ try:
+ embeddings = await client.get_embeddings_async(**_input_dict)
+ except Exception as e:
+ raise VertexAIError(status_code=500, message=str(e))
+
+ ## LOGGING POST-CALL
+ logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
+ ## Populate OpenAI compliant dictionary
+ embedding_response = []
+ input_tokens: int = 0
+ for idx, embedding in enumerate(embeddings):
+ embedding_response.append(
+ {
+ "object": "embedding",
+ "index": idx,
+ "embedding": embedding.values,
+ }
+ )
+ input_tokens += embedding.statistics.token_count
+
+ model_response.object = "list"
+ model_response.data = embedding_response
+ model_response.model = model
+ usage = Usage(
+ prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+ )
+ setattr(model_response, "usage", usage)
+ return model_response
diff --git a/litellm/main.py b/litellm/main.py
index d77d860584..95a1063772 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -125,12 +125,12 @@ from .llms.vertex_ai_and_google_ai_studio import (
vertex_ai_anthropic,
vertex_ai_non_gemini,
)
-from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import (
- GoogleBatchEmbeddings,
-)
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
+from .llms.vertex_ai_and_google_ai_studio.gemini_embeddings.batch_embed_content_handler import (
+ GoogleBatchEmbeddings,
+)
from .llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import (
VertexMultimodalEmbedding,
)
@@ -140,6 +140,9 @@ from .llms.vertex_ai_and_google_ai_studio.text_to_speech.text_to_speech_handler
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
VertexAIPartnerModels,
)
+from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings import (
+ embedding_handler as vertex_ai_embedding_handler,
+)
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
from .types.utils import (
@@ -3608,7 +3611,7 @@ def embedding(
custom_llm_provider="vertex_ai",
)
else:
- response = vertex_ai_non_gemini.embedding(
+ response = vertex_ai_embedding_handler.embedding(
model=model,
input=input,
encoding=encoding,
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index f2bfd092fc..4b659944a0 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -2014,6 +2014,25 @@ def test_vertexai_embedding_embedding_latest():
pytest.fail(f"Error occurred: {e}")
+@pytest.mark.flaky(retries=3, delay=1)
+def test_vertexai_embedding_embedding_latest_input_type():
+ try:
+ load_vertex_ai_credentials()
+ litellm.set_verbose = True
+
+ response = embedding(
+ model="vertex_ai/text-embedding-004",
+ input=["hi"],
+ input_type="RETRIEVAL_QUERY",
+ )
+ assert response.usage.prompt_tokens > 0
+ print(f"response:", response)
+ except litellm.RateLimitError as e:
+ pass
+ except Exception as e:
+ pytest.fail(f"Error occurred: {e}")
+
+
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_vertexai_aembedding():
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index baf528d8bc..8fdf722f06 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
-# litellm.num_retries=3
+# litellm.num_retries = 3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
diff --git a/litellm/utils.py b/litellm/utils.py
index c5739fcc45..ec4ac79c0f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2621,8 +2621,11 @@ def get_optional_params_embeddings(
request_type="embeddings",
)
_check_valid_arg(supported_params=supported_params)
- optional_params = litellm.VertexAITextEmbeddingConfig().map_openai_params(
- non_default_params=non_default_params, optional_params={}
+ (
+ optional_params,
+ kwargs,
+ ) = litellm.VertexAITextEmbeddingConfig().map_openai_params(
+ non_default_params=non_default_params, optional_params={}, kwargs=kwargs
)
final_params = {**optional_params, **kwargs}
return final_params