diff --git a/litellm/__init__.py b/litellm/__init__.py
index c7f290754..880e2ed04 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -918,9 +918,13 @@ from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gem
     GoogleAIStudioGeminiConfig,
     VertexAIConfig,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
+
+from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation import (
     VertexAITextEmbeddingConfig,
 )
+
+vertexAITextEmbeddingConfig = VertexAITextEmbeddingConfig()
+
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
     VertexAIAnthropicConfig,
 )
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py
index 5638c58cd..4e6fb2b59 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py
@@ -3,311 +3,234 @@
 import os
 import types
 from typing import Literal, Optional, Union
 
+import httpx
 from pydantic import BaseModel
 
 import litellm
 from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+    get_async_httpx_client,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
     VertexAIError,
 )
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
 from litellm.types.llms.vertex_ai import *
 from litellm.utils import Usage
 
+from .transformation import VertexAITextEmbeddingConfig
+from .types import *
 
 
-class VertexAITextEmbeddingConfig(BaseModel):
-    """
-    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
-    Args:
-        auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
-        task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
-        title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
-    """
+class VertexEmbedding(VertexBase):
+    def __init__(self) -> None:
+        super().__init__()
 
-    auto_truncate: Optional[bool] = None
-    task_type: Optional[
-        Literal[
-            "RETRIEVAL_QUERY",
-            "RETRIEVAL_DOCUMENT",
-            "SEMANTIC_SIMILARITY",
-            "CLASSIFICATION",
-            "CLUSTERING",
-            "QUESTION_ANSWERING",
-            "FACT_VERIFICATION",
-        ]
-    ] = None
-    title: Optional[str] = None
-
-    def __init__(
+    def embedding(
         self,
-        auto_truncate: Optional[bool] = None,
-        task_type: Optional[
-            Literal[
-                "RETRIEVAL_QUERY",
-                "RETRIEVAL_DOCUMENT",
-                "SEMANTIC_SIMILARITY",
-                "CLASSIFICATION",
-                "CLUSTERING",
-                "QUESTION_ANSWERING",
-                "FACT_VERIFICATION",
-            ]
-        ] = None,
-        title: Optional[str] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return ["dimensions"]
-
-    def map_openai_params(
-        self, non_default_params: dict, optional_params: dict, kwargs: dict
+        model: str,
+        input: Union[list, str],
+        print_verbose,
+        model_response: litellm.EmbeddingResponse,
+        optional_params: dict,
+        logging_obj: LiteLLMLoggingObject,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_key: Optional[str] = None,
+        encoding=None,
+        aembedding=False,
+        api_base: Optional[str] = None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[str] = None,
+        gemini_api_key: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
     ):
-        for param, value in non_default_params.items():
-            if param == "dimensions":
-                optional_params["output_dimensionality"] = value
+        if aembedding == True:
+            return self.async_embedding(
+                model=model,
+                input=input,
+                logging_obj=logging_obj,
+                model_response=model_response,
+                optional_params=optional_params,
+                encoding=encoding,
+                custom_llm_provider=custom_llm_provider,
+                timeout=timeout,
+                api_base=api_base,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                vertex_credentials=vertex_credentials,
+                gemini_api_key=gemini_api_key,
+                extra_headers=extra_headers,
+            )
 
-        if "input_type" in kwargs:
-            optional_params["task_type"] = kwargs.pop("input_type")
-        return optional_params, kwargs
-
-    def get_mapped_special_auth_params(self) -> dict:
-        """
-        Common auth params across bedrock/vertex_ai/azure/watsonx
-        """
-        return {"project": "vertex_project", "region_name": "vertex_location"}
-
-    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
-        mapped_params = self.get_mapped_special_auth_params()
-
-        for param, value in non_default_params.items():
-            if param in mapped_params:
-                optional_params[mapped_params[param]] = value
-        return optional_params
-
-
-def embedding(
-    model: str,
-    input: Union[list, str],
-    print_verbose,
-    model_response: litellm.EmbeddingResponse,
-    optional_params: dict,
-    api_key: Optional[str] = None,
-    logging_obj=None,
-    encoding=None,
-    vertex_project=None,
-    vertex_location=None,
-    vertex_credentials=None,
-    aembedding=False,
-):
-    # logic for parsing in - calling - parsing out model embedding calls
-    try:
-        import vertexai
-    except:
-        raise VertexAIError(
-            status_code=400,
-            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
         )
-    import google.auth  # type: ignore
-    from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
-
-    ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
-    try:
-        print_verbose(
-            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
+        _auth_header, vertex_project = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider=custom_llm_provider,
+        )
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+            mode="embedding",
+        )
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
+        vertex_request: VertexEmbeddingRequest = (
+            litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
+                input=input, optional_params=optional_params
+            )
         )
-        if vertex_credentials is not None and isinstance(vertex_credentials, str):
-            import google.oauth2.service_account
 
-            json_obj = json.loads(vertex_credentials)
+        _client_params = {}
+        if timeout:
+            _client_params["timeout"] = timeout
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client(params=_client_params)
+        else:
+            client = client  # type: ignore
+        ## LOGGING
+        logging_obj.pre_call(
+            input=vertex_request,
+            api_key="",
+            additional_args={
+                "complete_input_dict": vertex_request,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
 
-            creds = google.oauth2.service_account.Credentials.from_service_account_info(
-                json_obj,
-                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+        try:
+            response = client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        _json_response = response.json()
+        ## LOGGING POST-CALL
+        logging_obj.post_call(
+            input=input, api_key=None, original_response=_json_response
+        )
+
+        model_response = (
+            litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
+                response=_json_response, model=model, model_response=model_response
+            )
+        )
+
+        return model_response
+
+    async def async_embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        model_response: litellm.EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObject,
+        optional_params: dict,
+        custom_llm_provider: Literal[
+            "vertex_ai", "vertex_ai_beta", "gemini"
+        ],  # if it's vertex_ai or gemini (google ai studio)
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: Optional[str] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+        vertex_project: Optional[str] = None,
+        vertex_location: Optional[str] = None,
+        vertex_credentials: Optional[str] = None,
+        gemini_api_key: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
+        encoding=None,
+    ) -> litellm.EmbeddingResponse:
+        """
+        Async embedding implementation
+        """
+        should_use_v1beta1_features = self.is_using_v1beta1_features(
+            optional_params=optional_params
+        )
+        _auth_header, vertex_project = await self._ensure_access_token_async(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider=custom_llm_provider,
+        )
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=gemini_api_key,
+            auth_header=_auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=should_use_v1beta1_features,
+            mode="embedding",
+        )
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
+        vertex_request: VertexEmbeddingRequest = (
+            litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
+                input=input, optional_params=optional_params
+            )
+        )
+
+        _async_client_params = {}
+        if timeout:
+            _async_client_params["timeout"] = timeout
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            client = get_async_httpx_client(
+                params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
             )
         else:
-            creds, _ = google.auth.default(quota_project_id=vertex_project)
-        print_verbose(
-            f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
-        )
-        vertexai.init(
-            project=vertex_project, location=vertex_location, credentials=creds
-        )
-    except Exception as e:
-        raise VertexAIError(status_code=401, message=str(e))
-
-    if isinstance(input, str):
-        input = [input]
-
-    if optional_params is not None and isinstance(optional_params, dict):
-        if optional_params.get("task_type") or optional_params.get("title"):
-            # if user passed task_type or title, cast to TextEmbeddingInput
-            _task_type = optional_params.pop("task_type", None)
-            _title = optional_params.pop("title", None)
-            input = [
-                TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
-                for x in input
-            ]
-
-    try:
-        llm_model = TextEmbeddingModel.from_pretrained(model)
-    except Exception as e:
-        raise VertexAIError(status_code=422, message=str(e))
-
-    if aembedding == True:
-        return async_embedding(
-            model=model,
-            client=llm_model,
-            input=input,
-            logging_obj=logging_obj,
-            model_response=model_response,
-            optional_params=optional_params,
-            encoding=encoding,
+            client = client  # type: ignore
+        ## LOGGING
+        logging_obj.pre_call(
+            input=vertex_request,
+            api_key="",
+            additional_args={
+                "complete_input_dict": vertex_request,
+                "api_base": api_base,
+                "headers": headers,
+            },
         )
 
-    _input_dict = {"texts": input, **optional_params}
-    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
-    ## LOGGING PRE-CALL
-    logging_obj.pre_call(
-        input=input,
-        api_key=None,
-        additional_args={
-            "complete_input_dict": optional_params,
-            "request_str": request_str,
-        },
-    )
+        try:
+            response = await client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
 
-    try:
-        embeddings = llm_model.get_embeddings(**_input_dict)
-    except Exception as e:
-        raise VertexAIError(status_code=500, message=str(e))
-
-    ## LOGGING POST-CALL
-    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
-    ## Populate OpenAI compliant dictionary
-    embedding_response = []
-    input_tokens: int = 0
-    for idx, embedding in enumerate(embeddings):
-        embedding_response.append(
-            {
-                "object": "embedding",
-                "index": idx,
-                "embedding": embedding.values,
-            }
+        _json_response = response.json()
+        ## LOGGING POST-CALL
+        logging_obj.post_call(
+            input=input, api_key=None, original_response=_json_response
         )
-        input_tokens += embedding.statistics.token_count  # type: ignore
-    model_response.object = "list"
-    model_response.data = embedding_response
-    model_response.model = model
-    usage = Usage(
-        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-    )
-    setattr(model_response, "usage", usage)
-
-    return model_response
-
-
-async def async_embedding(
-    model: str,
-    input: Union[list, str],
-    model_response: litellm.EmbeddingResponse,
-    logging_obj=None,
-    optional_params=None,
-    encoding=None,
-    client=None,
-):
-    """
-    Async embedding implementation
-    """
-    _input_dict = {"texts": input, **optional_params}
-    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
-    ## LOGGING PRE-CALL
-    logging_obj.pre_call(
-        input=input,
-        api_key=None,
-        additional_args={
-            "complete_input_dict": optional_params,
-            "request_str": request_str,
-        },
-    )
-
-    try:
-        embeddings = await client.get_embeddings_async(**_input_dict)
-    except Exception as e:
-        raise VertexAIError(status_code=500, message=str(e))
-
-    ## LOGGING POST-CALL
-    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
-    ## Populate OpenAI compliant dictionary
-    embedding_response = []
-    input_tokens: int = 0
-    for idx, embedding in enumerate(embeddings):
-        embedding_response.append(
-            {
-                "object": "embedding",
-                "index": idx,
-                "embedding": embedding.values,
-            }
+        model_response = (
+            litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
+                response=_json_response, model=model, model_response=model_response
+            )
         )
-        input_tokens += embedding.statistics.token_count
-    model_response.object = "list"
-    model_response.data = embedding_response
-    model_response.model = model
-    usage = Usage(
-        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-    )
-    setattr(model_response, "usage", usage)
-    return model_response
-
-
-async def transform_vertex_response_to_openai(
-    response: dict, model: str, model_response: litellm.EmbeddingResponse
-) -> litellm.EmbeddingResponse:
-
-    _predictions = response["predictions"]
-
-    embedding_response = []
-    input_tokens: int = 0
-    for idx, element in enumerate(_predictions):
-
-        embedding = element["embeddings"]
-        embedding_response.append(
-            {
-                "object": "embedding",
-                "index": idx,
-                "embedding": embedding["values"],
-            }
-        )
-        input_tokens += embedding["statistics"]["token_count"]
-
-    model_response.object = "list"
-    model_response.data = embedding_response
-    model_response.model = model
-    usage = Usage(
-        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-    )
-    setattr(model_response, "usage", usage)
-    return model_response
+        return model_response
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py
new file mode 100644
index 000000000..1ca405392
--- /dev/null
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py
@@ -0,0 +1,183 @@
+import types
+from typing import List, Literal, Optional, Union
+
+from pydantic import BaseModel
+
+import litellm
+from litellm.utils import Usage
+
+from .types import *
+
+
+class VertexAITextEmbeddingConfig(BaseModel):
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
+
+    Args:
+        auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
+        task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
+        title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
+    """
+
+    auto_truncate: Optional[bool] = None
+    task_type: Optional[
+        Literal[
+            "RETRIEVAL_QUERY",
+            "RETRIEVAL_DOCUMENT",
+            "SEMANTIC_SIMILARITY",
+            "CLASSIFICATION",
+            "CLUSTERING",
+            "QUESTION_ANSWERING",
+            "FACT_VERIFICATION",
+        ]
+    ] = None
+    title: Optional[str] = None
+
+    def __init__(
+        self,
+        auto_truncate: Optional[bool] = None,
+        task_type: Optional[
+            Literal[
+                "RETRIEVAL_QUERY",
+                "RETRIEVAL_DOCUMENT",
+                "SEMANTIC_SIMILARITY",
+                "CLASSIFICATION",
+                "CLUSTERING",
+                "QUESTION_ANSWERING",
+                "FACT_VERIFICATION",
+            ]
+        ] = None,
+        title: Optional[str] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return ["dimensions"]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict, kwargs: dict
+    ):
+        for param, value in non_default_params.items():
+            if param == "dimensions":
+                optional_params["output_dimensionality"] = value
+
+        if "input_type" in kwargs:
+            optional_params["task_type"] = kwargs.pop("input_type")
+        return optional_params, kwargs
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+    def transform_openai_request_to_vertex_embedding_request(
+        self, input: Union[list, str], optional_params: dict
+    ) -> VertexEmbeddingRequest:
+        """
+        Transforms an openai request to a vertex embedding request.
+        """
+        vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
+        vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
+        task_type: Optional[TaskType] = optional_params.get("task_type")
+        title = optional_params.get("title")
+
+        if isinstance(input, str):
+            input = [input]  # Convert single string to list for uniform processing
+
+        for text in input:
+            embedding_input = self.create_embedding_input(
+                content=text, task_type=task_type, title=title
+            )
+            vertex_text_embedding_input_list.append(embedding_input)
+
+        vertex_request["instances"] = vertex_text_embedding_input_list
+        vertex_request["parameters"] = EmbeddingParameters(**optional_params)
+
+        return vertex_request
+
+    def create_embedding_input(
+        self,
+        content: str,
+        task_type: Optional[TaskType] = None,
+        title: Optional[str] = None,
+    ) -> TextEmbeddingInput:
+        """
+        Creates a TextEmbeddingInput object.
+
+        Vertex requires a List of TextEmbeddingInput objects. This helper function creates a single TextEmbeddingInput object.
+
+        Args:
+            content (str): The content to be embedded.
+            task_type (Optional[TaskType]): The type of task to be performed".
+            title (Optional[str]): The title of the document to be embedded
+
+        Returns:
+            TextEmbeddingInput: A TextEmbeddingInput object.
+        """
+        text_embedding_input = TextEmbeddingInput(content=content)
+        if task_type is not None:
+            text_embedding_input["task_type"] = task_type
+        if title is not None:
+            text_embedding_input["title"] = title
+        return text_embedding_input
+
+    def transform_vertex_response_to_openai(
+        self, response: dict, model: str, model_response: litellm.EmbeddingResponse
+    ) -> litellm.EmbeddingResponse:
+        """
+        Transforms a vertex embedding response to an openai response.
+        """
+        _predictions = response["predictions"]
+
+        embedding_response = []
+        input_tokens: int = 0
+        for idx, element in enumerate(_predictions):
+
+            embedding = element["embeddings"]
+            embedding_response.append(
+                {
+                    "object": "embedding",
+                    "index": idx,
+                    "embedding": embedding["values"],
+                }
+            )
+            input_tokens += embedding["statistics"]["token_count"]
+
+        model_response.object = "list"
+        model_response.data = embedding_response
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py
new file mode 100644
index 000000000..311809c82
--- /dev/null
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py
@@ -0,0 +1,49 @@
+"""
+Types for Vertex Embeddings Requests
+"""
+
+from enum import Enum
+from typing import List, Literal, Optional, TypedDict, Union
+
+
+class TaskType(str, Enum):
+    RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
+    RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
+    SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
+    CLASSIFICATION = "CLASSIFICATION"
+    CLUSTERING = "CLUSTERING"
+    QUESTION_ANSWERING = "QUESTION_ANSWERING"
+    FACT_VERIFICATION = "FACT_VERIFICATION"
+    CODE_RETRIEVAL_QUERY = "CODE_RETRIEVAL_QUERY"
+
+
+class TextEmbeddingInput(TypedDict, total=False):
+    content: str
+    task_type: Optional[TaskType]
+    title: Optional[str]
+
+
+class EmbeddingParameters(TypedDict, total=False):
+    auto_truncate: Optional[bool]
+    output_dimensionality: Optional[int]
+
+
+class VertexEmbeddingRequest(TypedDict, total=False):
+    instances: List[TextEmbeddingInput]
+    parameters: Optional[EmbeddingParameters]
+
+
+# Example usage:
+# example_request: VertexEmbeddingRequest = {
+#     "instances": [
+#         {
+#             "content": "I would like embeddings for this text!",
+#             "task_type": "RETRIEVAL_DOCUMENT",
+#             "title": "document title"
+#         }
+#     ],
+#     "parameters": {
+#         "auto_truncate": True,
+#         "output_dimensionality": None
+#     }
+# }
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_llm_base.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_llm_base.py
index 740bdca5c..cf130bb14 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_llm_base.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_llm_base.py
@@ -303,3 +303,16 @@ class VertexBase(BaseLLM):
             raise RuntimeError("Could not resolve API token from the environment")
 
         return self._credentials.token, project_id or self.project_id
+
+    def set_headers(
+        self, auth_header: Optional[str], extra_headers: Optional[dict]
+    ) -> dict:
+        headers = {
+            "Content-Type": "application/json",
+        }
+        if auth_header is not None:
+            headers["Authorization"] = f"Bearer {auth_header}"
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
+        return headers
diff --git a/litellm/main.py b/litellm/main.py
index c681c3b6e..9d9ca5991 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -134,8 +134,8 @@ from .llms.vertex_ai_and_google_ai_studio.text_to_speech.text_to_speech_handler
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
     VertexAIPartnerModels,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings import (
-    embedding_handler as vertex_ai_embedding_handler,
+from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
+    VertexEmbedding,
 )
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
@@ -185,6 +185,7 @@ bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 bedrock_embedding = BedrockEmbedding()
 vertex_chat_completion = VertexLLM()
+vertex_embedding = VertexEmbedding()
 vertex_multimodal_embedding = VertexMultimodalEmbedding()
 vertex_image_generation = VertexImageGeneration()
 google_batch_embeddings = GoogleBatchEmbeddings()
@@ -2980,7 +2981,7 @@ def batch_completion(
     deployment_id=None,
     request_timeout: Optional[int] = None,
     timeout: Optional[int] = 600,
-    max_workers:Optional[int]= 100,
+    max_workers: Optional[int] = 100,
     # Optional liteLLM function params
     **kwargs,
 ):
@@ -3711,21 +3712,21 @@ def embedding(
             optional_params.pop("vertex_project", None)
             or optional_params.pop("vertex_ai_project", None)
             or litellm.vertex_project
-            or get_secret("VERTEXAI_PROJECT")
-            or get_secret("VERTEX_PROJECT")
+            or get_secret_str("VERTEXAI_PROJECT")
+            or get_secret_str("VERTEX_PROJECT")
         )
         vertex_ai_location = (
             optional_params.pop("vertex_location", None)
             or optional_params.pop("vertex_ai_location", None)
            or litellm.vertex_location
-            or get_secret("VERTEXAI_LOCATION")
-            or get_secret("VERTEX_LOCATION")
+            or get_secret_str("VERTEXAI_LOCATION")
+            or get_secret_str("VERTEX_LOCATION")
        )
         vertex_credentials = (
             optional_params.pop("vertex_credentials", None)
             or optional_params.pop("vertex_ai_credentials", None)
-            or get_secret("VERTEXAI_CREDENTIALS")
-            or get_secret("VERTEX_CREDENTIALS")
+            or get_secret_str("VERTEXAI_CREDENTIALS")
+            or get_secret_str("VERTEX_CREDENTIALS")
         )
 
         if (
@@ -3750,7 +3751,7 @@
                 custom_llm_provider="vertex_ai",
             )
         else:
-            response = vertex_ai_embedding_handler.embedding(
+            response = vertex_embedding.embedding(
                 model=model,
                 input=input,
                 encoding=encoding,
@@ -3760,6 +3761,8 @@
                 vertex_project=vertex_ai_project,
                 vertex_location=vertex_ai_location,
                 vertex_credentials=vertex_credentials,
+                custom_llm_provider="vertex_ai",
+                timeout=timeout,
                 aembedding=aembedding,
                 print_verbose=print_verbose,
             )
diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py
index fe46ae58c..45ba10f1c 100644
--- a/litellm/proxy/pass_through_endpoints/success_handler.py
+++ b/litellm/proxy/pass_through_endpoints/success_handler.py
@@ -129,9 +129,6 @@
         from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
             VertexImageGeneration,
         )
-        from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
-            transform_vertex_response_to_openai,
-        )
         from litellm.types.utils import PassthroughCallTypes
 
         vertex_image_generation_class = VertexImageGeneration()
@@ -157,7 +154,7 @@
                 PassthroughCallTypes.passthrough_image_generation.value
             )
         else:
-            litellm_prediction_response = await transform_vertex_response_to_openai(
+            litellm_prediction_response = litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
                 response=_json_response,
                 model=model,
                 model_response=litellm.EmbeddingResponse(),
diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py
index 0fed67d5f..604c2eed8 100644
--- a/tests/local_testing/test_amazing_vertex_completion.py
+++ b/tests/local_testing/test_amazing_vertex_completion.py
@@ -1861,15 +1861,40 @@ async def test_gemini_pro_async_function_calling():
 
 
 @pytest.mark.flaky(retries=3, delay=1)
-def test_vertexai_embedding():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_vertexai_embedding(sync_mode):
     try:
         load_vertex_ai_credentials()
-        # litellm.set_verbose = True
-        response = embedding(
-            model="textembedding-gecko@001",
-            input=["good morning from litellm", "this is another item"],
-        )
-        print(f"response:", response)
+        litellm.set_verbose = True
+
+        input_text = ["good morning from litellm", "this is another item"]
+
+        if sync_mode:
+            response = litellm.embedding(
+                model="textembedding-gecko@001", input=input_text
+            )
+        else:
+            response = await litellm.aembedding(
+                model="textembedding-gecko@001", input=input_text
+            )
+
+        print(f"response: {response}")
+
+        # Assert that the response is not None
+        assert response is not None
+
+        # Assert that the response contains embeddings
+        assert hasattr(response, "data")
+        assert len(response.data) == len(input_text)
+
+        # Assert that each embedding is a non-empty list of floats
+        for embedding in response.data:
+            assert "embedding" in embedding
+            assert isinstance(embedding["embedding"], list)
+            assert len(embedding["embedding"]) > 0
+            assert all(isinstance(x, float) for x in embedding["embedding"])
+
     except litellm.RateLimitError as e:
         pass
     except Exception as e:
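
Reviewer note (not part of the patch): a minimal sketch of how the two transformation hooks added in `transformation.py` compose, using the `litellm.vertexAITextEmbeddingConfig` singleton this PR registers in `litellm/__init__.py`. The Vertex response payload below is invented for illustration; real payloads come from the REST call in `VertexEmbedding.embedding()`.

```python
import litellm

config = litellm.vertexAITextEmbeddingConfig

# OpenAI-style input -> Vertex "predict" request body
vertex_request = config.transform_openai_request_to_vertex_embedding_request(
    input=["good morning from litellm"],
    optional_params={"task_type": "RETRIEVAL_DOCUMENT"},
)
# vertex_request == {
#     "instances": [
#         {"content": "good morning from litellm", "task_type": "RETRIEVAL_DOCUMENT"}
#     ],
#     "parameters": {"task_type": "RETRIEVAL_DOCUMENT"},
# }

# Vertex "predictions" payload -> OpenAI-style EmbeddingResponse
# (embedding values and token_count are made up for this sketch)
fake_vertex_response = {
    "predictions": [
        {
            "embeddings": {
                "values": [0.1, 0.2, 0.3],
                "statistics": {"token_count": 4},
            }
        }
    ]
}
openai_response = config.transform_vertex_response_to_openai(
    response=fake_vertex_response,
    model="textembedding-gecko@001",
    model_response=litellm.EmbeddingResponse(),
)
assert openai_response.usage.prompt_tokens == 4
assert openai_response.data[0]["embedding"] == [0.1, 0.2, 0.3]
```

One thing worth double-checking in review: `transform_openai_request_to_vertex_embedding_request` reads `task_type`/`title` with `.get()` rather than `.pop()`, so (as the comment above shows) they also pass through into `parameters` via `EmbeddingParameters(**optional_params)`; at runtime a `TypedDict` constructor does not filter unknown keys.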