From a6273a29fee456124d16725c3e00ab16b2b35215 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 30 Aug 2024 09:19:48 -0700 Subject: [PATCH] add test for test_vertexai_multimodal_embedding_text_input --- .../vertex_and_google_ai_studio_gemini.py | 162 ------------- .../embedding_handler.py | 216 ++++++++++++++++++ litellm/main.py | 9 +- .../tests/test_amazing_vertex_completion.py | 55 +++++ 4 files changed, 278 insertions(+), 164 deletions(-) create mode 100644 litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py index ed7cd8dba..238018821 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py @@ -41,15 +41,12 @@ from litellm.types.llms.vertex_ai import ( FunctionDeclaration, GenerateContentResponseBody, GenerationConfig, - Instance, - InstanceVideo, PartType, RequestBody, SafetSettingsConfig, SystemInstructions, ToolConfig, Tools, - VertexMultimodalEmbeddingRequest, ) from litellm.types.utils import GenericStreamingChunk from litellm.utils import CustomStreamWrapper, ModelResponse, Usage @@ -811,10 +808,6 @@ class VertexLLM(BaseLLM): self._credentials: Optional[Any] = None self.project_id: Optional[str] = None self.async_handler: Optional[AsyncHTTPHandler] = None - self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [ - "multimodalembedding", - "multimodalembedding@001", - ] def _process_response( self, @@ -1727,161 +1720,6 @@ class VertexLLM(BaseLLM): return model_response - def multimodal_embedding( - self, - model: str, - input: Union[list, str], - print_verbose, - model_response: litellm.EmbeddingResponse, - custom_llm_provider: Literal["gemini", "vertex_ai"], - optional_params: dict, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - logging_obj=None, - encoding=None, - vertex_project=None, - vertex_location=None, - vertex_credentials=None, - aembedding=False, - timeout=300, - client=None, - ): - auth_header, url = self._get_token_and_url( - model=model, - gemini_api_key=api_key, - vertex_project=vertex_project, - vertex_location=vertex_location, - vertex_credentials=vertex_credentials, - stream=None, - custom_llm_provider=custom_llm_provider, - api_base=api_base, - should_use_v1beta1_features=False, - mode="embedding", - ) - - if client is None: - _params = {} - if timeout is not None: - if isinstance(timeout, float) or isinstance(timeout, int): - _httpx_timeout = httpx.Timeout(timeout) - _params["timeout"] = _httpx_timeout - else: - _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) - - sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore - else: - sync_handler = client # type: ignore - - optional_params = optional_params or {} - - request_data = VertexMultimodalEmbeddingRequest() - - if "instances" in optional_params: - request_data["instances"] = optional_params["instances"] - elif isinstance(input, list): - request_data["instances"] = input - else: - # construct instances - vertex_request_instance = Instance(**optional_params) - - if isinstance(input, str): - vertex_request_instance["text"] = input - - request_data["instances"] = [vertex_request_instance] - - headers = { - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {auth_header}", - } - - ## 
LOGGING - logging_obj.pre_call( - input=input, - api_key="", - additional_args={ - "complete_input_dict": request_data, - "api_base": url, - "headers": headers, - }, - ) - - if aembedding is True: - return self.async_multimodal_embedding( - model=model, - api_base=url, - data=request_data, - timeout=timeout, - headers=headers, - client=client, - model_response=model_response, - ) - - response = sync_handler.post( - url=url, - headers=headers, - data=json.dumps(request_data), - ) - - if response.status_code != 200: - raise Exception(f"Error: {response.status_code} {response.text}") - - _json_response = response.json() - if "predictions" not in _json_response: - raise litellm.InternalServerError( - message=f"embedding response does not contain 'predictions', got {_json_response}", - llm_provider="vertex_ai", - model=model, - ) - _predictions = _json_response["predictions"] - - model_response.data = _predictions - model_response.model = model - - return model_response - - async def async_multimodal_embedding( - self, - model: str, - api_base: str, - data: VertexMultimodalEmbeddingRequest, - model_response: litellm.EmbeddingResponse, - timeout: Optional[Union[float, httpx.Timeout]], - headers={}, - client: Optional[AsyncHTTPHandler] = None, - ) -> litellm.EmbeddingResponse: - if client is None: - _params = {} - if timeout is not None: - if isinstance(timeout, float) or isinstance(timeout, int): - timeout = httpx.Timeout(timeout) - _params["timeout"] = timeout - client = AsyncHTTPHandler(**_params) # type: ignore - else: - client = client # type: ignore - - try: - response = await client.post(api_base, headers=headers, json=data) # type: ignore - response.raise_for_status() - except httpx.HTTPStatusError as err: - error_code = err.response.status_code - raise VertexAIError(status_code=error_code, message=err.response.text) - except httpx.TimeoutException: - raise VertexAIError(status_code=408, message="Timeout error occurred.") - - _json_response = response.json() - if "predictions" not in _json_response: - raise litellm.InternalServerError( - message=f"embedding response does not contain 'predictions', got {_json_response}", - llm_provider="vertex_ai", - model=model, - ) - _predictions = _json_response["predictions"] - - model_response.data = _predictions - model_response.model = model - - return model_response - class ModelResponseIterator: def __init__(self, streaming_response, sync_stream: bool): diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py new file mode 100644 index 000000000..455358191 --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py @@ -0,0 +1,216 @@ +import json +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import httpx + +import litellm +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( + VertexAIError, + VertexLLM, +) +from litellm.types.llms.vertex_ai import ( + Instance, + InstanceVideo, + VertexMultimodalEmbeddingRequest, +) + + +class VertexMultimodalEmbedding(VertexLLM): + def __init__(self) -> None: + super().__init__() + self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [ + "multimodalembedding", + "multimodalembedding@001", + ] + + def multimodal_embedding( + self, + model: str, + input: Union[list, str], + 
print_verbose, + model_response: litellm.EmbeddingResponse, + custom_llm_provider: Literal["gemini", "vertex_ai"], + optional_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + logging_obj=None, + encoding=None, + vertex_project=None, + vertex_location=None, + vertex_credentials=None, + aembedding=False, + timeout=300, + client=None, + ): + auth_header, url = self._get_token_and_url( + model=model, + gemini_api_key=api_key, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + stream=None, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + should_use_v1beta1_features=False, + mode="embedding", + ) + + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + + sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore + else: + sync_handler = client # type: ignore + + optional_params = optional_params or {} + + request_data = VertexMultimodalEmbeddingRequest() + + if "instances" in optional_params: + request_data["instances"] = optional_params["instances"] + elif isinstance(input, list): + vertex_instances: List[Instance] = self.process_openai_embedding_input( + _input=input + ) + request_data["instances"] = vertex_instances + + else: + # construct instances + vertex_request_instance = Instance(**optional_params) + + if isinstance(input, str): + vertex_request_instance["text"] = input + + request_data["instances"] = [vertex_request_instance] + + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {auth_header}", + } + + ## LOGGING + logging_obj.pre_call( + input=input, + api_key="", + additional_args={ + "complete_input_dict": request_data, + "api_base": url, + "headers": headers, + }, + ) + + if aembedding is True: + return self.async_multimodal_embedding( + model=model, + api_base=url, + data=request_data, + timeout=timeout, + headers=headers, + client=client, + model_response=model_response, + ) + + response = sync_handler.post( + url=url, + headers=headers, + data=json.dumps(request_data), + ) + + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") + + _json_response = response.json() + if "predictions" not in _json_response: + raise litellm.InternalServerError( + message=f"embedding response does not contain 'predictions', got {_json_response}", + llm_provider="vertex_ai", + model=model, + ) + _predictions = _json_response["predictions"] + + model_response.data = _predictions + model_response.model = model + + return model_response + + async def async_multimodal_embedding( + self, + model: str, + api_base: str, + data: VertexMultimodalEmbeddingRequest, + model_response: litellm.EmbeddingResponse, + timeout: Optional[Union[float, httpx.Timeout]], + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> litellm.EmbeddingResponse: + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = AsyncHTTPHandler(**_params) # type: ignore + else: + client = client # type: ignore + + try: + response = await client.post(api_base, headers=headers, json=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: 
+ error_code = err.response.status_code + raise VertexAIError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException: + raise VertexAIError(status_code=408, message="Timeout error occurred.") + + _json_response = response.json() + if "predictions" not in _json_response: + raise litellm.InternalServerError( + message=f"embedding response does not contain 'predictions', got {_json_response}", + llm_provider="vertex_ai", + model=model, + ) + _predictions = _json_response["predictions"] + + model_response.data = _predictions + model_response.model = model + + return model_response + + def process_openai_embedding_input( + self, _input: Union[list, str] + ) -> List[Instance]: + """ + Process the input for multimodal embedding requests. + + Args: + _input (Union[list, str]): The input data to process. + + Returns: + List[Instance]: A list of processed VertexAI Instance objects. + """ + + _input_list = None + if not isinstance(_input, list): + _input_list = [_input] + else: + _input_list = _input + + processed_instances = [] + for element in _input: + if not isinstance(element, dict): + # assuming that input is a list of strings + # example: input = ["hello from litellm"] + instance = Instance(text=element) + else: + # assume this is a + instance = Instance(**element) + processed_instances.append(instance) + + return processed_instances diff --git a/litellm/main.py b/litellm/main.py index ca9d145f1..36c711caf 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -132,6 +132,9 @@ from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) +from .llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import ( + VertexMultimodalEmbedding, +) from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import ( VertexAIPartnerModels, ) @@ -175,6 +178,7 @@ triton_chat_completions = TritonChatCompletion() bedrock_chat_completion = BedrockLLM() bedrock_converse_chat_completion = BedrockConverseLLM() vertex_chat_completion = VertexLLM() +vertex_multimodal_embedding = VertexMultimodalEmbedding() google_batch_embeddings = GoogleBatchEmbeddings() vertex_partner_models_chat_completion = VertexAIPartnerModels() vertex_text_to_speech = VertexTextToSpeechAPI() @@ -3581,10 +3585,11 @@ def embedding( if ( "image" in optional_params or "video" in optional_params - or model in vertex_chat_completion.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS + or model + in vertex_multimodal_embedding.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS ): # multimodal embedding is supported on vertex httpx - response = vertex_chat_completion.multimodal_embedding( + response = vertex_multimodal_embedding.multimodal_embedding( model=model, input=input, encoding=encoding, diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index a4238995d..472d6be58 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -1934,6 +1934,61 @@ async def test_vertexai_multimodal_embedding(): print("Response:", response) +@pytest.mark.asyncio +async def test_vertexai_multimodal_embedding_text_input(): + load_vertex_ai_credentials() + mock_response = AsyncMock() + + def return_val(): + return { + "predictions": [ + { + "textEmbedding": [0.4, 0.5, 0.6], # Simplified example + } + ] + } + + mock_response.json = return_val + mock_response.status_code = 200 + + expected_payload 
= { + "instances": [ + { + "text": "this is a unicorn", + } + ] + } + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + return_value=mock_response, + ) as mock_post: + # Act: Call the litellm.aembedding function + response = await litellm.aembedding( + model="vertex_ai/multimodalembedding@001", + input=[ + "this is a unicorn", + ], + ) + + # Assert + mock_post.assert_called_once() + _, kwargs = mock_post.call_args + args_to_vertexai = kwargs["json"] + + print("args to vertex ai call:", args_to_vertexai) + + assert args_to_vertexai == expected_payload + assert response.model == "multimodalembedding@001" + assert len(response.data) == 1 + response_data = response.data[0] + assert "textEmbedding" in response_data + + # Optional: Print for debugging + print("Arguments passed to Vertex AI:", args_to_vertexai) + print("Response:", response) + + @pytest.mark.skip( reason="new test - works locally running into vertex version issues on ci/cd" )
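
Usage note (illustrative, not part of the patch): the new handler's process_openai_embedding_input maps OpenAI-style `input` values onto Vertex AI `instances`, so plain strings become {"text": ...} entries while dict elements are passed through as-is. The sketch below is a minimal stand-alone approximation of that mapping; the helper name to_vertex_instances is hypothetical, and it builds plain dicts where the real code builds litellm's Instance TypedDict.

    from typing import Dict, List, Union

    def to_vertex_instances(_input: Union[list, str]) -> List[Dict]:
        # Wrap a bare string so we always iterate over a list of elements,
        # never over the characters of a single string.
        elements = [_input] if not isinstance(_input, list) else _input

        instances: List[Dict] = []
        for element in elements:
            if isinstance(element, dict):
                # Already a Vertex-style instance, e.g. {"image": {...}} or {"video": {...}}.
                instances.append(dict(element))
            else:
                # Plain text input, e.g. "this is a unicorn".
                instances.append({"text": element})
        return instances

    print(to_vertex_instances(["this is a unicorn"]))
    # -> [{'text': 'this is a unicorn'}]

This is the shape asserted by the new test: calling litellm.aembedding with model="vertex_ai/multimodalembedding@001" and input=["this is a unicorn"] produces a request body of {"instances": [{"text": "this is a unicorn"}]} (a real call additionally requires valid Vertex AI credentials).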