diff --git a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py index f63d1ce11e..cd3de6e526 100644 --- a/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py +++ b/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py @@ -226,7 +226,15 @@ class VertexMultimodalEmbedding(VertexLLM): else: return Instance(image=InstanceImage(gcsUri=input_element)) elif is_base64_encoded(s=input_element): - return Instance(image=InstanceImage(bytesBase64Encoded=input_element)) + return Instance( + image=InstanceImage( + bytesBase64Encoded=( + input_element.split(",")[1] + if "," in input_element + else input_element + ) + ) + ) else: return Instance(text=input_element) diff --git a/litellm/main.py b/litellm/main.py index 1826f2df78..3d4152d634 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3715,6 +3715,7 @@ def embedding( # noqa: PLR0915 aembedding=aembedding, print_verbose=print_verbose, custom_llm_provider="vertex_ai", + client=client, ) else: response = vertex_embedding.embedding( diff --git a/tests/litellm/llms/vertex_ai/multimodal_embeddings/test_embedding_handler.py b/tests/litellm/llms/vertex_ai/multimodal_embeddings/test_embedding_handler.py new file mode 100644 index 0000000000..cfd1167836 --- /dev/null +++ b/tests/litellm/llms/vertex_ai/multimodal_embeddings/test_embedding_handler.py @@ -0,0 +1,43 @@ +import json +import os +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +sys.path.insert( + 0, os.path.abspath("../../../../..") +) # Adds the parent directory to the system path + +from litellm.llms.vertex_ai.multimodal_embeddings.embedding_handler import ( + VertexMultimodalEmbedding, +) +from litellm.types.llms.vertex_ai import Instance, InstanceImage + + +class TestVertexMultimodalEmbedding: + def setup_method(self): + self.embedding_handler = VertexMultimodalEmbedding() + + def test_process_openai_embedding_input(self): + input_data = [ + "", + "", + ] + expected_output = [ + Instance( + image=InstanceImage(bytesBase64Encoded=input_data[0].split(",")[1]) + ), + Instance( + image=InstanceImage(bytesBase64Encoded=input_data[1].split(",")[1]) + ), + ] + assert ( + self.embedding_handler._process_input_element(input_data[0]) + == expected_output[0] + ) + assert ( + self.embedding_handler._process_input_element(input_data[1]) + == expected_output[1] + ) diff --git a/tests/llm_translation/duck.png b/tests/llm_translation/duck.png new file mode 100644 index 0000000000..541ce2fc75 Binary files /dev/null and b/tests/llm_translation/duck.png differ diff --git a/tests/llm_translation/guinea.png b/tests/llm_translation/guinea.png new file mode 100644 index 0000000000..4c8537f72b Binary files /dev/null and b/tests/llm_translation/guinea.png differ diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py index da6fd4e285..d7a373d58d 100644 --- a/tests/llm_translation/test_vertex.py +++ b/tests/llm_translation/test_vertex.py @@ -1,3 +1,5 @@ +import base64 +import numpy as np import json import os import sys @@ -25,6 +27,11 @@ from litellm.types.llms.vertex_ai import PartType, BlobType import httpx +def encode_image_to_base64(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + def test_completion_pydantic_obj_2(): from pydantic import BaseModel from litellm.llms.custom_httpx.http_handler import HTTPHandler @@ -1264,6 +1271,25 @@ from typing import Dict, Any # from your_module import _process_gemini_image, PartType, FileDataType, BlobType +# Add these fixtures below existing fixtures +@pytest.fixture +def vertex_client(): + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + return HTTPHandler() + + +@pytest.fixture +def encoded_images(): + image_paths = [ + "./tests/llm_translation/duck.png", + # "./duck.png", + "./tests/llm_translation/guinea.png", + # "./guinea.png", + ] + return [encode_image_to_base64(path) for path in image_paths] + + @pytest.fixture def mock_convert_url_to_base64(): with patch( @@ -1305,3 +1331,73 @@ def test_process_gemini_image_http_url( # Act result = _process_gemini_image(http_url) # assert result["file_data"]["file_uri"] == http_url + + +@pytest.mark.parametrize( + "input_string, expected_closer_index", + [ + ("Duck", 0), # Duck closer to duck image + ("Guinea", 1), # Guinea closer to guinea image + ], +) +def test_aaavertex_embeddings_distances( + vertex_client, encoded_images, input_string, expected_closer_index +): + """ + Test cosine distances between image and text embeddings using Vertex AI multimodalembedding@001 + """ + from unittest.mock import patch + + # Mock different embedding values to simulate realistic distances + mock_image_embeddings = [ + [0.9] + [0.1] * 767, # Duck embedding - closer to "Duck" + [0.1] * 767 + [0.9], # Guinea embedding - closer to "Guinea" + ] + + image_embeddings = [] + mock_response = MagicMock() + + def mock_auth_token(*args, **kwargs): + return "my-fake-token", "pathrise-project" + + with patch.object(vertex_client, "post", return_value=mock_response), patch.object( + litellm.main.vertex_multimodal_embedding, + "_ensure_access_token", + side_effect=mock_auth_token, + ): + for idx, encoded_image in enumerate(encoded_images): + mock_response.json.return_value = { + "predictions": [{"imageEmbedding": mock_image_embeddings[idx]}] + } + mock_response.status_code = 200 + response = litellm.embedding( + model="vertex_ai/multimodalembedding@001", + input=[f"data:image/png;base64,{encoded_image}"], + client=vertex_client, + ) + print("response: ", response) + image_embeddings.append(response.data[0].embedding) + + # Mock text embedding based on input string + mock_text_embedding = ( + [0.9] + [0.1] * 767 if input_string == "Duck" else [0.1] * 767 + [0.9] + ) + text_mock_response = MagicMock() + text_mock_response.json.return_value = { + "predictions": [{"imageEmbedding": mock_text_embedding}] + } + text_mock_response.status_code = 200 + with patch.object( + vertex_client, "post", return_value=text_mock_response + ), patch.object( + litellm.main.vertex_multimodal_embedding, + "_ensure_access_token", + side_effect=mock_auth_token, + ): + text_response = litellm.embedding( + model="vertex_ai/multimodalembedding@001", + input=[input_string], + client=vertex_client, + ) + print("text_response: ", text_response) + text_embedding = text_response.data[0].embedding