(performance improvement - vertex embeddings) ~111.11% faster (#6000)

* use vertex llm as base class for embeddings

* use correct vertex class in main.py

* set_headers in vertex llm base

* add types for vertex embedding requests

* add embedding handler for vertex

* use async mode for vertex embedding tests

* use vertexAI textEmbeddingConfig

* fix linting

* add sync and async mode testing for vertex ai embeddings
Authored by Ishaan Jaff on 2024-10-01 14:16:21 -07:00; committed by GitHub
parent 18a28ef977
commit eef9bad9a6
8 changed files with 497 additions and 300 deletions
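
For context, a minimal usage sketch of the two call paths this commit reroutes through the new HTTP-based handler (model and input are taken from the tests below; the project/location values are illustrative and assume valid GCP credentials):

import asyncio

import litellm

# Sync path -> routed to VertexEmbedding.embedding(), which POSTs via httpx.
response = litellm.embedding(
    model="textembedding-gecko@001",
    input=["good morning from litellm", "this is another item"],
    vertex_project="my-gcp-project",  # illustrative
    vertex_location="us-central1",  # illustrative
)

# Async path -> routed to VertexEmbedding.async_embedding(); no blocking SDK calls.
async def main():
    return await litellm.aembedding(
        model="textembedding-gecko@001",
        input=["good morning from litellm", "this is another item"],
    )

asyncio.run(main())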

litellm/__init__.py

@@ -918,9 +918,13 @@
     GoogleAIStudioGeminiConfig,
     VertexAIConfig,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation import (
     VertexAITextEmbeddingConfig,
 )
+
+vertexAITextEmbeddingConfig = VertexAITextEmbeddingConfig()
+
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
     VertexAIAnthropicConfig,
 )

litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py

@@ -3,311 +3,234 @@ import os
 import types
 from typing import Literal, Optional, Union

+import httpx
 from pydantic import BaseModel

 import litellm
 from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+    get_async_httpx_client,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
     VertexAIError,
 )
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
 from litellm.types.llms.vertex_ai import *
 from litellm.utils import Usage

+from .transformation import VertexAITextEmbeddingConfig
+from .types import *

-class VertexAITextEmbeddingConfig(BaseModel):
-    [... config class body removed here unchanged; it moves verbatim to transformation.py, shown in full below ...]
-
-
-def embedding(
-    model: str,
-    input: Union[list, str],
-    print_verbose,
-    model_response: litellm.EmbeddingResponse,
-    optional_params: dict,
-    api_key: Optional[str] = None,
-    logging_obj=None,
-    encoding=None,
-    vertex_project=None,
-    vertex_location=None,
-    vertex_credentials=None,
-    aembedding=False,
-):
-    # logic for parsing in - calling - parsing out model embedding calls
-    try:
-        import vertexai
-    except:
-        raise VertexAIError(
-            status_code=400,
-            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
-        )
-
-    import google.auth  # type: ignore
-    from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
-
-    ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
-    try:
-        print_verbose(
-            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
-        )
-        if vertex_credentials is not None and isinstance(vertex_credentials, str):
-            import google.oauth2.service_account
-
-            json_obj = json.loads(vertex_credentials)
-
-            creds = google.oauth2.service_account.Credentials.from_service_account_info(
-                json_obj,
-                scopes=["https://www.googleapis.com/auth/cloud-platform"],
-            )
-        else:
-            creds, _ = google.auth.default(quota_project_id=vertex_project)
-        print_verbose(
-            f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
-        )
-        vertexai.init(
-            project=vertex_project, location=vertex_location, credentials=creds
-        )
-    except Exception as e:
-        raise VertexAIError(status_code=401, message=str(e))
-
-    if isinstance(input, str):
-        input = [input]
-
-    if optional_params is not None and isinstance(optional_params, dict):
-        if optional_params.get("task_type") or optional_params.get("title"):
-            # if user passed task_type or title, cast to TextEmbeddingInput
-            _task_type = optional_params.pop("task_type", None)
-            _title = optional_params.pop("title", None)
-            input = [
-                TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
-                for x in input
-            ]
-
-    try:
-        llm_model = TextEmbeddingModel.from_pretrained(model)
-    except Exception as e:
-        raise VertexAIError(status_code=422, message=str(e))
-
-    if aembedding == True:
-        return async_embedding(
-            model=model,
-            client=llm_model,
-            input=input,
-            logging_obj=logging_obj,
-            model_response=model_response,
-            optional_params=optional_params,
-            encoding=encoding,
-        )
-
-    _input_dict = {"texts": input, **optional_params}
-    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
-    ## LOGGING PRE-CALL
-    logging_obj.pre_call(
-        input=input,
-        api_key=None,
-        additional_args={
-            "complete_input_dict": optional_params,
-            "request_str": request_str,
-        },
-    )
-
-    try:
-        embeddings = llm_model.get_embeddings(**_input_dict)
-    except Exception as e:
-        raise VertexAIError(status_code=500, message=str(e))
-
-    ## LOGGING POST-CALL
-    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
-    ## Populate OpenAI compliant dictionary
-    embedding_response = []
-    input_tokens: int = 0
-    for idx, embedding in enumerate(embeddings):
-        embedding_response.append(
-            {
-                "object": "embedding",
-                "index": idx,
-                "embedding": embedding.values,
-            }
-        )
-        input_tokens += embedding.statistics.token_count  # type: ignore
-    model_response.object = "list"
-    model_response.data = embedding_response
-    model_response.model = model
-    usage = Usage(
-        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-    )
-    setattr(model_response, "usage", usage)
-    return model_response
-
-
-async def async_embedding(
-    model: str,
-    input: Union[list, str],
-    model_response: litellm.EmbeddingResponse,
-    logging_obj=None,
-    optional_params=None,
-    encoding=None,
-    client=None,
-):
-    """
-    Async embedding implementation
-    """
-    _input_dict = {"texts": input, **optional_params}
-    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
-    ## LOGGING PRE-CALL
-    logging_obj.pre_call(
-        input=input,
-        api_key=None,
-        additional_args={
-            "complete_input_dict": optional_params,
-            "request_str": request_str,
-        },
-    )
-
-    try:
-        embeddings = await client.get_embeddings_async(**_input_dict)
-    except Exception as e:
-        raise VertexAIError(status_code=500, message=str(e))
-
-    ## LOGGING POST-CALL
-    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
-    ## Populate OpenAI compliant dictionary
-    embedding_response = []
-    input_tokens: int = 0
-    for idx, embedding in enumerate(embeddings):
-        embedding_response.append(
-            {
-                "object": "embedding",
-                "index": idx,
-                "embedding": embedding.values,
-            }
-        )
-        input_tokens += embedding.statistics.token_count
-    model_response.object = "list"
-    model_response.data = embedding_response
-    model_response.model = model
-    usage = Usage(
-        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-    )
-    setattr(model_response, "usage", usage)
-    return model_response
-
-
-async def transform_vertex_response_to_openai(
-    response: dict, model: str, model_response: litellm.EmbeddingResponse
-) -> litellm.EmbeddingResponse:
-    [... removed; becomes VertexAITextEmbeddingConfig.transform_vertex_response_to_openai (no longer async) in transformation.py, shown in full below ...]
+
+class VertexEmbedding(VertexBase):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        print_verbose,
+        model_response: litellm.EmbeddingResponse,
+        optional_params: dict,
+        logging_obj: LiteLLMLoggingObject,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_key: Optional[str] = None,
+        encoding=None,
+        aembedding=False,
+        api_base: Optional[str] = None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[str] = None,
+        gemini_api_key: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
+    ):
+        if aembedding == True:
+            return self.async_embedding(
+                model=model,
+                input=input,
+                logging_obj=logging_obj,
+                model_response=model_response,
+                optional_params=optional_params,
+                encoding=encoding,
+                custom_llm_provider=custom_llm_provider,
+                timeout=timeout,
+                api_base=api_base,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                vertex_credentials=vertex_credentials,
+                gemini_api_key=gemini_api_key,
+                extra_headers=extra_headers,
+            )
+
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider=custom_llm_provider,
+        )
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+            mode="embedding",
+        )
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
+        vertex_request: VertexEmbeddingRequest = (
+            litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
+                input=input, optional_params=optional_params
+            )
+        )
+
+        _client_params = {}
+        if timeout:
+            _client_params["timeout"] = timeout
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client(params=_client_params)
+        else:
+            client = client  # type: ignore
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=vertex_request,
+            api_key="",
+            additional_args={
+                "complete_input_dict": vertex_request,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
+        try:
+            response = client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        _json_response = response.json()
+        ## LOGGING POST-CALL
+        logging_obj.post_call(
+            input=input, api_key=None, original_response=_json_response
+        )
+
+        model_response = (
+            litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
+                response=_json_response, model=model, model_response=model_response
+            )
+        )
+
+        return model_response
+
+    async def async_embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        model_response: litellm.EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObject,
+        optional_params: dict,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: Optional[str] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[str] = None,
+        gemini_api_key: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
+        encoding=None,
+    ) -> litellm.EmbeddingResponse:
+        """
+        Async embedding implementation
+        """
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider=custom_llm_provider,
+        )
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+            mode="embedding",
+        )
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
+        vertex_request: VertexEmbeddingRequest = (
+            litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
+                input=input, optional_params=optional_params
+            )
+        )
+
+        _async_client_params = {}
+        if timeout:
+            _async_client_params["timeout"] = timeout
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            client = get_async_httpx_client(
+                params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
+            )
+        else:
+            client = client  # type: ignore
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=vertex_request,
+            api_key="",
+            additional_args={
+                "complete_input_dict": vertex_request,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
+        try:
+            response = await client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        _json_response = response.json()
+        ## LOGGING POST-CALL
+        logging_obj.post_call(
+            input=input, api_key=None, original_response=_json_response
+        )
+
+        model_response = (
+            litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
+                response=_json_response, model=model, model_response=model_response
+            )
+        )
+
+        return model_response
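
Net effect of the handler rewrite: instead of initializing the vertexai SDK and calling TextEmbeddingModel.get_embeddings() per request, the handler resolves a token, builds the request body via the transformation config, and POSTs it directly over httpx. A sketch of the resulting wire call (URL shape per Vertex AI's predict API; project, location, and token are placeholders):

import httpx

project, location, model = "my-project", "us-central1", "textembedding-gecko@001"  # placeholders
api_base = (
    f"https://{location}-aiplatform.googleapis.com/v1/projects/{project}"
    f"/locations/{location}/publishers/google/models/{model}:predict"
)
# Body shape matches VertexEmbeddingRequest in types.py below.
body = {"instances": [{"content": "good morning from litellm"}], "parameters": {}}
resp = httpx.post(
    api_base,
    headers={"Content-Type": "application/json", "Authorization": "Bearer <token>"},
    json=body,
)
resp.raise_for_status()
print(resp.json()["predictions"][0]["embeddings"]["values"][:3])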

litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py (new file)

@@ -0,0 +1,183 @@
import types
from typing import List, Literal, Optional, Union

from pydantic import BaseModel

import litellm
from litellm.utils import Usage

from .types import *


class VertexAITextEmbeddingConfig(BaseModel):
    """
    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput

    Args:
        auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
        task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
        title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
    """

    auto_truncate: Optional[bool] = None
    task_type: Optional[
        Literal[
            "RETRIEVAL_QUERY",
            "RETRIEVAL_DOCUMENT",
            "SEMANTIC_SIMILARITY",
            "CLASSIFICATION",
            "CLUSTERING",
            "QUESTION_ANSWERING",
            "FACT_VERIFICATION",
        ]
    ] = None
    title: Optional[str] = None

    def __init__(
        self,
        auto_truncate: Optional[bool] = None,
        task_type: Optional[
            Literal[
                "RETRIEVAL_QUERY",
                "RETRIEVAL_DOCUMENT",
                "SEMANTIC_SIMILARITY",
                "CLASSIFICATION",
                "CLUSTERING",
                "QUESTION_ANSWERING",
                "FACT_VERIFICATION",
            ]
        ] = None,
        title: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return ["dimensions"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, kwargs: dict
    ):
        for param, value in non_default_params.items():
            if param == "dimensions":
                optional_params["output_dimensionality"] = value

        if "input_type" in kwargs:
            optional_params["task_type"] = kwargs.pop("input_type")
        return optional_params, kwargs

    def get_mapped_special_auth_params(self) -> dict:
        """
        Common auth params across bedrock/vertex_ai/azure/watsonx
        """
        return {"project": "vertex_project", "region_name": "vertex_location"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        mapped_params = self.get_mapped_special_auth_params()
        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params

    def transform_openai_request_to_vertex_embedding_request(
        self, input: Union[list, str], optional_params: dict
    ) -> VertexEmbeddingRequest:
        """
        Transforms an openai request to a vertex embedding request.
        """
        vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
        vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
        task_type: Optional[TaskType] = optional_params.get("task_type")
        title = optional_params.get("title")

        if isinstance(input, str):
            input = [input]  # Convert single string to list for uniform processing

        for text in input:
            embedding_input = self.create_embedding_input(
                content=text, task_type=task_type, title=title
            )
            vertex_text_embedding_input_list.append(embedding_input)

        vertex_request["instances"] = vertex_text_embedding_input_list
        vertex_request["parameters"] = EmbeddingParameters(**optional_params)

        return vertex_request

    def create_embedding_input(
        self,
        content: str,
        task_type: Optional[TaskType] = None,
        title: Optional[str] = None,
    ) -> TextEmbeddingInput:
        """
        Creates a TextEmbeddingInput object.

        Vertex requires a List of TextEmbeddingInput objects. This helper function creates a single TextEmbeddingInput object.

        Args:
            content (str): The content to be embedded.
            task_type (Optional[TaskType]): The type of task to be performed.
            title (Optional[str]): The title of the document to be embedded.

        Returns:
            TextEmbeddingInput: A TextEmbeddingInput object.
        """
        text_embedding_input = TextEmbeddingInput(content=content)
        if task_type is not None:
            text_embedding_input["task_type"] = task_type
        if title is not None:
            text_embedding_input["title"] = title
        return text_embedding_input

    def transform_vertex_response_to_openai(
        self, response: dict, model: str, model_response: litellm.EmbeddingResponse
    ) -> litellm.EmbeddingResponse:
        """
        Transforms a vertex embedding response to an openai response.
        """
        _predictions = response["predictions"]

        embedding_response = []
        input_tokens: int = 0
        for idx, element in enumerate(_predictions):
            embedding = element["embeddings"]
            embedding_response.append(
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": embedding["values"],
                }
            )
            input_tokens += embedding["statistics"]["token_count"]

        model_response.object = "list"
        model_response.data = embedding_response
        model_response.model = model
        usage = Usage(
            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
        )
        setattr(model_response, "usage", usage)
        return model_response
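
A quick sketch of the two transforms in isolation, using the module-level instance registered in litellm/__init__.py above (the response dict is a fabricated minimal payload matching the shape the parser reads, not captured Vertex output):

import litellm

config = litellm.vertexAITextEmbeddingConfig

# OpenAI-style input -> Vertex request body
vertex_request = config.transform_openai_request_to_vertex_embedding_request(
    input=["hello world"], optional_params={"task_type": "RETRIEVAL_QUERY"}
)
# -> {"instances": [{"content": "hello world", "task_type": "RETRIEVAL_QUERY"}],
#     "parameters": {...}}

# Vertex response body -> litellm.EmbeddingResponse
fake_response = {
    "predictions": [
        {"embeddings": {"values": [0.1, 0.2], "statistics": {"token_count": 4}}}
    ]
}
openai_response = config.transform_vertex_response_to_openai(
    response=fake_response,
    model="textembedding-gecko@001",
    model_response=litellm.EmbeddingResponse(),
)
print(openai_response.usage.prompt_tokens)  # 4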

litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py (new file)

@@ -0,0 +1,49 @@
"""
Types for Vertex Embeddings Requests
"""

from enum import Enum
from typing import List, Literal, Optional, TypedDict, Union


class TaskType(str, Enum):
    RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
    RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
    SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
    CLASSIFICATION = "CLASSIFICATION"
    CLUSTERING = "CLUSTERING"
    QUESTION_ANSWERING = "QUESTION_ANSWERING"
    FACT_VERIFICATION = "FACT_VERIFICATION"
    CODE_RETRIEVAL_QUERY = "CODE_RETRIEVAL_QUERY"


class TextEmbeddingInput(TypedDict, total=False):
    content: str
    task_type: Optional[TaskType]
    title: Optional[str]


class EmbeddingParameters(TypedDict, total=False):
    auto_truncate: Optional[bool]
    output_dimensionality: Optional[int]


class VertexEmbeddingRequest(TypedDict, total=False):
    instances: List[TextEmbeddingInput]
    parameters: Optional[EmbeddingParameters]


# Example usage:
# example_request: VertexEmbeddingRequest = {
#     "instances": [
#         {
#             "content": "I would like embeddings for this text!",
#             "task_type": "RETRIEVAL_DOCUMENT",
#             "title": "document title"
#         }
#     ],
#     "parameters": {
#         "auto_truncate": True,
#         "output_dimensionality": None
#     }
# }

litellm/llms/vertex_ai_and_google_ai_studio/vertex_llm_base.py

@@ -303,3 +303,16 @@ class VertexBase(BaseLLM):
             raise RuntimeError("Could not resolve API token from the environment")

         return self._credentials.token, project_id or self.project_id
+
+    def set_headers(
+        self, auth_header: Optional[str], extra_headers: Optional[dict]
+    ) -> dict:
+        headers = {
+            "Content-Type": "application/json",
+        }
+        if auth_header is not None:
+            headers["Authorization"] = f"Bearer {auth_header}"
+        if extra_headers is not None:
+            headers.update(extra_headers)
+        return headers
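
For reference, a small sketch of what this helper yields on any VertexBase-derived instance (token and extra header values are illustrative):

# vertex_base is e.g. the VertexEmbedding() handler above
headers = vertex_base.set_headers(
    auth_header="ya29.token",  # illustrative OAuth access token
    extra_headers={"x-goog-user-project": "my-project"},  # illustrative
)
# -> {"Content-Type": "application/json",
#     "Authorization": "Bearer ya29.token",
#     "x-goog-user-project": "my-project"}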

litellm/main.py

@@ -134,8 +134,8 @@ from .llms.vertex_ai_and_google_ai_studio.text_to_speech.text_to_speech_handler
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
     VertexAIPartnerModels,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings import (
-    embedding_handler as vertex_ai_embedding_handler,
+from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
+    VertexEmbedding,
 )
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent

@@ -185,6 +185,7 @@ bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 bedrock_embedding = BedrockEmbedding()
 vertex_chat_completion = VertexLLM()
+vertex_embedding = VertexEmbedding()
 vertex_multimodal_embedding = VertexMultimodalEmbedding()
 vertex_image_generation = VertexImageGeneration()
 google_batch_embeddings = GoogleBatchEmbeddings()

@@ -3711,21 +3712,21 @@ def embedding(
             optional_params.pop("vertex_project", None)
             or optional_params.pop("vertex_ai_project", None)
             or litellm.vertex_project
-            or get_secret("VERTEXAI_PROJECT")
-            or get_secret("VERTEX_PROJECT")
+            or get_secret_str("VERTEXAI_PROJECT")
+            or get_secret_str("VERTEX_PROJECT")
         )
         vertex_ai_location = (
             optional_params.pop("vertex_location", None)
             or optional_params.pop("vertex_ai_location", None)
             or litellm.vertex_location
-            or get_secret("VERTEXAI_LOCATION")
-            or get_secret("VERTEX_LOCATION")
+            or get_secret_str("VERTEXAI_LOCATION")
+            or get_secret_str("VERTEX_LOCATION")
         )
         vertex_credentials = (
             optional_params.pop("vertex_credentials", None)
             or optional_params.pop("vertex_ai_credentials", None)
-            or get_secret("VERTEXAI_CREDENTIALS")
-            or get_secret("VERTEX_CREDENTIALS")
+            or get_secret_str("VERTEXAI_CREDENTIALS")
+            or get_secret_str("VERTEX_CREDENTIALS")
         )

         if (

@@ -3750,7 +3751,7 @@ def embedding(
                 custom_llm_provider="vertex_ai",
             )
         else:
-            response = vertex_ai_embedding_handler.embedding(
+            response = vertex_embedding.embedding(
                 model=model,
                 input=input,
                 encoding=encoding,

@@ -3760,6 +3761,8 @@ def embedding(
                 vertex_project=vertex_ai_project,
                 vertex_location=vertex_ai_location,
                 vertex_credentials=vertex_credentials,
+                custom_llm_provider="vertex_ai",
+                timeout=timeout,
                 aembedding=aembedding,
                 print_verbose=print_verbose,
             )
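
The get_secret -> get_secret_str swap uses the str-typed variant of the secret lookup, so these values resolve as Optional[str] rather than Any; the resolution order itself is unchanged. Illustrative environment setup using the variable names read above:

import os

os.environ["VERTEXAI_PROJECT"] = "my-project"  # or VERTEX_PROJECT
os.environ["VERTEXAI_LOCATION"] = "us-central1"  # or VERTEX_LOCATION
# VERTEXAI_CREDENTIALS / VERTEX_CREDENTIALS may hold a service-account JSON string.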

litellm/proxy/pass_through_endpoints/success_handler.py

@@ -129,9 +129,6 @@ class PassThroughEndpointLogging:
         from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
             VertexImageGeneration,
         )
-        from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
-            transform_vertex_response_to_openai,
-        )
         from litellm.types.utils import PassthroughCallTypes

         vertex_image_generation_class = VertexImageGeneration()

@@ -157,7 +154,7 @@ class PassThroughEndpointLogging:
                     PassthroughCallTypes.passthrough_image_generation.value
                 )
             else:
-                litellm_prediction_response = await transform_vertex_response_to_openai(
+                litellm_prediction_response = litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
                     response=_json_response,
                     model=model,
                     model_response=litellm.EmbeddingResponse(),

litellm/tests/test_amazing_vertex_completion.py

@@ -1861,15 +1861,40 @@ async def test_gemini_pro_async_function_calling():
 @pytest.mark.flaky(retries=3, delay=1)
-def test_vertexai_embedding():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_vertexai_embedding(sync_mode):
     try:
         load_vertex_ai_credentials()
-        # litellm.set_verbose = True
-        response = embedding(
-            model="textembedding-gecko@001",
-            input=["good morning from litellm", "this is another item"],
-        )
-        print(f"response:", response)
+        litellm.set_verbose = True
+
+        input_text = ["good morning from litellm", "this is another item"]
+
+        if sync_mode:
+            response = litellm.embedding(
+                model="textembedding-gecko@001", input=input_text
+            )
+        else:
+            response = await litellm.aembedding(
+                model="textembedding-gecko@001", input=input_text
+            )
+
+        print(f"response: {response}")
+
+        # Assert that the response is not None
+        assert response is not None
+
+        # Assert that the response contains embeddings
+        assert hasattr(response, "data")
+        assert len(response.data) == len(input_text)
+
+        # Assert that each embedding is a non-empty list of floats
+        for embedding in response.data:
+            assert "embedding" in embedding
+            assert isinstance(embedding["embedding"], list)
+            assert len(embedding["embedding"]) > 0
+            assert all(isinstance(x, float) for x in embedding["embedding"])
+
     except litellm.RateLimitError as e:
         pass
     except Exception as e: