fix vertex ai multimodal embedding translation (#9471)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 20s
Helm unit test / unit-test (push) Successful in 24s

* remove data:image/jpeg;base64, prefix from base64 image input

Vertex AI's multimodal embeddings endpoint expects a raw base64 string, without a data-URI prefix such as `data:image/jpeg;base64,`.

* Add Vertex Multimodal Embedding Test

* fix(test_vertex.py): add e2e tests on multimodal embeddings

* test: unit testing

* test: remove sklearn dep

* test: update test with fixed route

* test: fix test

---------

Co-authored-by: Jonarod <jonrodd@gmail.com>
Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
This commit is contained in:
Krish Dholakia 2025-03-24 23:23:28 -07:00 committed by GitHub
parent 75994d0bf0
commit 92883560f0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 149 additions and 1 deletions

View file

@ -226,7 +226,15 @@ class VertexMultimodalEmbedding(VertexLLM):
else:
return Instance(image=InstanceImage(gcsUri=input_element))
elif is_base64_encoded(s=input_element):
return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
return Instance(
image=InstanceImage(
bytesBase64Encoded=(
input_element.split(",")[1]
if "," in input_element
else input_element
)
)
)
else:
return Instance(text=input_element)

View file

@ -3715,6 +3715,7 @@ def embedding( # noqa: PLR0915
aembedding=aembedding,
print_verbose=print_verbose,
custom_llm_provider="vertex_ai",
client=client,
)
else:
response = vertex_embedding.embedding(

View file

@ -0,0 +1,43 @@
import json
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
sys.path.insert(
0, os.path.abspath("../../../../..")
) # Adds the parent directory to the system path
from litellm.llms.vertex_ai.multimodal_embeddings.embedding_handler import (
VertexMultimodalEmbedding,
)
from litellm.types.llms.vertex_ai import Instance, InstanceImage
class TestVertexMultimodalEmbedding:
    """Unit tests for how ``VertexMultimodalEmbedding`` translates inputs."""

    def setup_method(self):
        """Build a fresh handler before every test method."""
        self.embedding_handler = VertexMultimodalEmbedding()

    def test_process_openai_embedding_input(self):
        """A ``data:`` URI image becomes an image ``Instance`` without its prefix.

        The expected ``bytesBase64Encoded`` value is the part of the input
        after the first comma, i.e. the payload with the
        ``data:image/png;base64,`` header removed.
        """
        input_data = [
            "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=",
            "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=",
        ]

        # One check per input, in order, instead of hand-copied assert pairs.
        for data_uri in input_data:
            expected = Instance(
                image=InstanceImage(bytesBase64Encoded=data_uri.split(",")[1])
            )
            assert (
                self.embedding_handler._process_input_element(data_uri)
                == expected
            )

Binary file not shown.

After

Width:  |  Height:  |  Size: 535 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

View file

@ -1,3 +1,5 @@
import base64
import numpy as np
import json
import os
import sys
@ -25,6 +27,11 @@ from litellm.types.llms.vertex_ai import PartType, BlobType
import httpx
def encode_image_to_base64(image_path) -> str:
    """Return the contents of the file at ``image_path`` as a base64 string.

    The file is opened in binary mode, its full contents are base64-encoded,
    and the encoded bytes are decoded to ``str`` as UTF-8. No ``data:`` URI
    prefix is added to the result.
    """
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
def test_completion_pydantic_obj_2():
from pydantic import BaseModel
from litellm.llms.custom_httpx.http_handler import HTTPHandler
@ -1264,6 +1271,25 @@ from typing import Dict, Any
# from your_module import _process_gemini_image, PartType, FileDataType, BlobType
# Add these fixtures below existing fixtures
@pytest.fixture
def vertex_client():
    """Provide a new ``HTTPHandler`` instance to the requesting test."""
    # Imported inside the fixture, as in the original, rather than at module level.
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    return HTTPHandler()
@pytest.fixture
def encoded_images():
    """Return the base64-encoded duck and guinea test images, in that order.

    The image paths are relative to the current working directory, so the
    files must be reachable at ``./tests/llm_translation/`` from wherever the
    tests are launched.
    """
    image_paths = [
        "./tests/llm_translation/duck.png",
        "./tests/llm_translation/guinea.png",
    ]
    return [encode_image_to_base64(path) for path in image_paths]
@pytest.fixture
def mock_convert_url_to_base64():
with patch(
@ -1305,3 +1331,73 @@ def test_process_gemini_image_http_url(
# Act
result = _process_gemini_image(http_url)
# assert result["file_data"]["file_uri"] == http_url
@pytest.mark.parametrize(
"input_string, expected_closer_index",
[
("Duck", 0), # Duck closer to duck image
("Guinea", 1), # Guinea closer to guinea image
],
)
def test_aaavertex_embeddings_distances(
vertex_client, encoded_images, input_string, expected_closer_index
):
"""
Test cosine distances between image and text embeddings using Vertex AI multimodalembedding@001
"""
from unittest.mock import patch
# Mock different embedding values to simulate realistic distances
mock_image_embeddings = [
[0.9] + [0.1] * 767, # Duck embedding - closer to "Duck"
[0.1] * 767 + [0.9], # Guinea embedding - closer to "Guinea"
]
image_embeddings = []
mock_response = MagicMock()
def mock_auth_token(*args, **kwargs):
return "my-fake-token", "pathrise-project"
with patch.object(vertex_client, "post", return_value=mock_response), patch.object(
litellm.main.vertex_multimodal_embedding,
"_ensure_access_token",
side_effect=mock_auth_token,
):
for idx, encoded_image in enumerate(encoded_images):
mock_response.json.return_value = {
"predictions": [{"imageEmbedding": mock_image_embeddings[idx]}]
}
mock_response.status_code = 200
response = litellm.embedding(
model="vertex_ai/multimodalembedding@001",
input=[f"data:image/png;base64,{encoded_image}"],
client=vertex_client,
)
print("response: ", response)
image_embeddings.append(response.data[0].embedding)
# Mock text embedding based on input string
mock_text_embedding = (
[0.9] + [0.1] * 767 if input_string == "Duck" else [0.1] * 767 + [0.9]
)
text_mock_response = MagicMock()
text_mock_response.json.return_value = {
"predictions": [{"imageEmbedding": mock_text_embedding}]
}
text_mock_response.status_code = 200
with patch.object(
vertex_client, "post", return_value=text_mock_response
), patch.object(
litellm.main.vertex_multimodal_embedding,
"_ensure_access_token",
side_effect=mock_auth_token,
):
text_response = litellm.embedding(
model="vertex_ai/multimodalembedding@001",
input=[input_string],
client=vertex_client,
)
print("text_response: ", text_response)
text_embedding = text_response.data[0].embedding