mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
fix vertex ai multimodal embedding translation (#9471)
* Remove the `data:image/jpeg;base64,` prefix from base64 image input. Vertex AI's multimodal embeddings endpoint expects a raw base64 string without the `data:image/jpeg;base64,` prefix.
* Add Vertex Multimodal Embedding Test
* fix(test_vertex.py): add e2e tests on multimodal embeddings
* test: unit testing
* test: remove sklearn dep
* test: update test with fixed route
* test: fix test
---------
Co-authored-by: Jonarod <jonrodd@gmail.com>
Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
This commit is contained in:
parent
75994d0bf0
commit
92883560f0
6 changed files with 149 additions and 1 deletions
|
@ -226,7 +226,15 @@ class VertexMultimodalEmbedding(VertexLLM):
|
|||
else:
|
||||
return Instance(image=InstanceImage(gcsUri=input_element))
|
||||
elif is_base64_encoded(s=input_element):
|
||||
return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
|
||||
return Instance(
|
||||
image=InstanceImage(
|
||||
bytesBase64Encoded=(
|
||||
input_element.split(",")[1]
|
||||
if "," in input_element
|
||||
else input_element
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
return Instance(text=input_element)
|
||||
|
||||
|
|
|
@ -3715,6 +3715,7 @@ def embedding( # noqa: PLR0915
|
|||
aembedding=aembedding,
|
||||
print_verbose=print_verbose,
|
||||
custom_llm_provider="vertex_ai",
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
response = vertex_embedding.embedding(
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.llms.vertex_ai.multimodal_embeddings.embedding_handler import (
|
||||
VertexMultimodalEmbedding,
|
||||
)
|
||||
from litellm.types.llms.vertex_ai import Instance, InstanceImage
|
||||
|
||||
|
||||
class TestVertexMultimodalEmbedding:
    """Unit tests for ``VertexMultimodalEmbedding._process_input_element``."""

    def setup_method(self):
        # Build a fresh handler for every test method.
        self.embedding_handler = VertexMultimodalEmbedding()

    def test_process_openai_embedding_input(self):
        """Data-URI image inputs become image instances carrying only the raw base64 payload."""
        # Empty strings cannot exercise the prefix stripping: "".split(",")[1]
        # raises IndexError before the handler is even compared. Use small,
        # valid base64 payloads (the JPEG and PNG magic bytes) behind a
        # data-URI prefix instead.
        # NOTE(review): assumes the handler routes well-formed data URIs to the
        # image branch (via is_base64_encoded) — confirm against the handler.
        input_data = [
            "data:image/jpeg;base64,/9j/4AAQSkZJRg==",
            "data:image/png;base64,iVBORw0KGgo=",
        ]
        # Expected payload is everything after the first comma, i.e. the
        # base64 body without the "data:<mime>;base64," prefix.
        expected_output = [
            Instance(image=InstanceImage(bytesBase64Encoded=data_uri.split(",")[1]))
            for data_uri in input_data
        ]
        for data_uri, expected in zip(input_data, expected_output):
            assert (
                self.embedding_handler._process_input_element(data_uri) == expected
            )
|
BIN
tests/llm_translation/duck.png
Normal file
BIN
tests/llm_translation/duck.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 535 KiB |
BIN
tests/llm_translation/guinea.png
Normal file
BIN
tests/llm_translation/guinea.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.5 MiB |
|
@ -1,3 +1,5 @@
|
|||
import base64
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
@ -25,6 +27,11 @@ from litellm.types.llms.vertex_ai import PartType, BlobType
|
|||
import httpx
|
||||
|
||||
|
||||
def encode_image_to_base64(image_path):
    """Read the file at ``image_path`` in binary mode and return it base64-encoded.

    The encoded bytes are decoded as UTF-8, so the result is a ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
|
||||
|
||||
|
||||
def test_completion_pydantic_obj_2():
|
||||
from pydantic import BaseModel
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
@ -1264,6 +1271,25 @@ from typing import Dict, Any
|
|||
# from your_module import _process_gemini_image, PartType, FileDataType, BlobType
|
||||
|
||||
|
||||
# Add these fixtures below existing fixtures
|
||||
@pytest.fixture
def vertex_client():
    """Provide a newly constructed ``HTTPHandler`` to the requesting test."""
    # Imported inside the fixture, as in the original.
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    http_handler = HTTPHandler()
    return http_handler
|
||||
|
||||
|
||||
@pytest.fixture
def encoded_images():
    """Return the base64-encoded duck and guinea test images, in that order.

    The image paths are relative to the current working directory.
    """
    # NOTE(review): "./duck.png" and "./guinea.png" were kept as commented-out
    # alternatives here — presumably for running from inside
    # tests/llm_translation; confirm before relying on either form.
    duck_path = "./tests/llm_translation/duck.png"
    guinea_path = "./tests/llm_translation/guinea.png"
    return list(map(encode_image_to_base64, (duck_path, guinea_path)))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_convert_url_to_base64():
|
||||
with patch(
|
||||
|
@ -1305,3 +1331,73 @@ def test_process_gemini_image_http_url(
|
|||
# Act
|
||||
result = _process_gemini_image(http_url)
|
||||
# assert result["file_data"]["file_uri"] == http_url
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_string, expected_closer_index",
|
||||
[
|
||||
("Duck", 0), # Duck closer to duck image
|
||||
("Guinea", 1), # Guinea closer to guinea image
|
||||
],
|
||||
)
|
||||
def test_aaavertex_embeddings_distances(
|
||||
vertex_client, encoded_images, input_string, expected_closer_index
|
||||
):
|
||||
"""
|
||||
Test cosine distances between image and text embeddings using Vertex AI multimodalembedding@001
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock different embedding values to simulate realistic distances
|
||||
mock_image_embeddings = [
|
||||
[0.9] + [0.1] * 767, # Duck embedding - closer to "Duck"
|
||||
[0.1] * 767 + [0.9], # Guinea embedding - closer to "Guinea"
|
||||
]
|
||||
|
||||
image_embeddings = []
|
||||
mock_response = MagicMock()
|
||||
|
||||
def mock_auth_token(*args, **kwargs):
|
||||
return "my-fake-token", "pathrise-project"
|
||||
|
||||
with patch.object(vertex_client, "post", return_value=mock_response), patch.object(
|
||||
litellm.main.vertex_multimodal_embedding,
|
||||
"_ensure_access_token",
|
||||
side_effect=mock_auth_token,
|
||||
):
|
||||
for idx, encoded_image in enumerate(encoded_images):
|
||||
mock_response.json.return_value = {
|
||||
"predictions": [{"imageEmbedding": mock_image_embeddings[idx]}]
|
||||
}
|
||||
mock_response.status_code = 200
|
||||
response = litellm.embedding(
|
||||
model="vertex_ai/multimodalembedding@001",
|
||||
input=[f"data:image/png;base64,{encoded_image}"],
|
||||
client=vertex_client,
|
||||
)
|
||||
print("response: ", response)
|
||||
image_embeddings.append(response.data[0].embedding)
|
||||
|
||||
# Mock text embedding based on input string
|
||||
mock_text_embedding = (
|
||||
[0.9] + [0.1] * 767 if input_string == "Duck" else [0.1] * 767 + [0.9]
|
||||
)
|
||||
text_mock_response = MagicMock()
|
||||
text_mock_response.json.return_value = {
|
||||
"predictions": [{"imageEmbedding": mock_text_embedding}]
|
||||
}
|
||||
text_mock_response.status_code = 200
|
||||
with patch.object(
|
||||
vertex_client, "post", return_value=text_mock_response
|
||||
), patch.object(
|
||||
litellm.main.vertex_multimodal_embedding,
|
||||
"_ensure_access_token",
|
||||
side_effect=mock_auth_token,
|
||||
):
|
||||
text_response = litellm.embedding(
|
||||
model="vertex_ai/multimodalembedding@001",
|
||||
input=[input_string],
|
||||
client=vertex_client,
|
||||
)
|
||||
print("text_response: ", text_response)
|
||||
text_embedding = text_response.data[0].embedding
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue