mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix vertex ai multimodal embedding translation (#9471)
* remove data:image/jpeg;base64, prefix from base64 image input vertex_ai's multimodal embeddings endpoint expects a raw base64 string without `data:image/jpeg;base64,` prefix. * Add Vertex Multimodal Embedding Test * fix(test_vertex.py): add e2e tests on multimodal embeddings * test: unit testing * test: remove sklearn dep * test: update test with fixed route * test: fix test --------- Co-authored-by: Jonarod <jonrodd@gmail.com> Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
This commit is contained in:
parent
75994d0bf0
commit
92883560f0
6 changed files with 149 additions and 1 deletions
|
@ -226,7 +226,15 @@ class VertexMultimodalEmbedding(VertexLLM):
|
||||||
else:
|
else:
|
||||||
return Instance(image=InstanceImage(gcsUri=input_element))
|
return Instance(image=InstanceImage(gcsUri=input_element))
|
||||||
elif is_base64_encoded(s=input_element):
|
elif is_base64_encoded(s=input_element):
|
||||||
return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
|
return Instance(
|
||||||
|
image=InstanceImage(
|
||||||
|
bytesBase64Encoded=(
|
||||||
|
input_element.split(",")[1]
|
||||||
|
if "," in input_element
|
||||||
|
else input_element
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return Instance(text=input_element)
|
return Instance(text=input_element)
|
||||||
|
|
||||||
|
|
|
@ -3715,6 +3715,7 @@ def embedding( # noqa: PLR0915
|
||||||
aembedding=aembedding,
|
aembedding=aembedding,
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
custom_llm_provider="vertex_ai",
|
custom_llm_provider="vertex_ai",
|
||||||
|
client=client,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = vertex_embedding.embedding(
|
response = vertex_embedding.embedding(
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from litellm.llms.vertex_ai.multimodal_embeddings.embedding_handler import (
|
||||||
|
VertexMultimodalEmbedding,
|
||||||
|
)
|
||||||
|
from litellm.types.llms.vertex_ai import Instance, InstanceImage
|
||||||
|
|
||||||
|
|
||||||
|
class TestVertexMultimodalEmbedding:
|
||||||
|
def setup_method(self):
|
||||||
|
self.embedding_handler = VertexMultimodalEmbedding()
|
||||||
|
|
||||||
|
def test_process_openai_embedding_input(self):
|
||||||
|
input_data = [
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
expected_output = [
|
||||||
|
Instance(
|
||||||
|
image=InstanceImage(bytesBase64Encoded=input_data[0].split(",")[1])
|
||||||
|
),
|
||||||
|
Instance(
|
||||||
|
image=InstanceImage(bytesBase64Encoded=input_data[1].split(",")[1])
|
||||||
|
),
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
self.embedding_handler._process_input_element(input_data[0])
|
||||||
|
== expected_output[0]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
self.embedding_handler._process_input_element(input_data[1])
|
||||||
|
== expected_output[1]
|
||||||
|
)
|
BIN
tests/llm_translation/duck.png
Normal file
BIN
tests/llm_translation/duck.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 535 KiB |
BIN
tests/llm_translation/guinea.png
Normal file
BIN
tests/llm_translation/guinea.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.5 MiB |
|
@ -1,3 +1,5 @@
|
||||||
|
import base64
|
||||||
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
@ -25,6 +27,11 @@ from litellm.types.llms.vertex_ai import PartType, BlobType
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
def encode_image_to_base64(image_path):
|
||||||
|
with open(image_path, "rb") as image_file:
|
||||||
|
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def test_completion_pydantic_obj_2():
|
def test_completion_pydantic_obj_2():
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
|
@ -1264,6 +1271,25 @@ from typing import Dict, Any
|
||||||
# from your_module import _process_gemini_image, PartType, FileDataType, BlobType
|
# from your_module import _process_gemini_image, PartType, FileDataType, BlobType
|
||||||
|
|
||||||
|
|
||||||
|
# Add these fixtures below existing fixtures
|
||||||
|
@pytest.fixture
|
||||||
|
def vertex_client():
|
||||||
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
|
|
||||||
|
return HTTPHandler()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def encoded_images():
|
||||||
|
image_paths = [
|
||||||
|
"./tests/llm_translation/duck.png",
|
||||||
|
# "./duck.png",
|
||||||
|
"./tests/llm_translation/guinea.png",
|
||||||
|
# "./guinea.png",
|
||||||
|
]
|
||||||
|
return [encode_image_to_base64(path) for path in image_paths]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_convert_url_to_base64():
|
def mock_convert_url_to_base64():
|
||||||
with patch(
|
with patch(
|
||||||
|
@ -1305,3 +1331,73 @@ def test_process_gemini_image_http_url(
|
||||||
# Act
|
# Act
|
||||||
result = _process_gemini_image(http_url)
|
result = _process_gemini_image(http_url)
|
||||||
# assert result["file_data"]["file_uri"] == http_url
|
# assert result["file_data"]["file_uri"] == http_url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_string, expected_closer_index",
|
||||||
|
[
|
||||||
|
("Duck", 0), # Duck closer to duck image
|
||||||
|
("Guinea", 1), # Guinea closer to guinea image
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_aaavertex_embeddings_distances(
|
||||||
|
vertex_client, encoded_images, input_string, expected_closer_index
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test cosine distances between image and text embeddings using Vertex AI multimodalembedding@001
|
||||||
|
"""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
# Mock different embedding values to simulate realistic distances
|
||||||
|
mock_image_embeddings = [
|
||||||
|
[0.9] + [0.1] * 767, # Duck embedding - closer to "Duck"
|
||||||
|
[0.1] * 767 + [0.9], # Guinea embedding - closer to "Guinea"
|
||||||
|
]
|
||||||
|
|
||||||
|
image_embeddings = []
|
||||||
|
mock_response = MagicMock()
|
||||||
|
|
||||||
|
def mock_auth_token(*args, **kwargs):
|
||||||
|
return "my-fake-token", "pathrise-project"
|
||||||
|
|
||||||
|
with patch.object(vertex_client, "post", return_value=mock_response), patch.object(
|
||||||
|
litellm.main.vertex_multimodal_embedding,
|
||||||
|
"_ensure_access_token",
|
||||||
|
side_effect=mock_auth_token,
|
||||||
|
):
|
||||||
|
for idx, encoded_image in enumerate(encoded_images):
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"predictions": [{"imageEmbedding": mock_image_embeddings[idx]}]
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
response = litellm.embedding(
|
||||||
|
model="vertex_ai/multimodalembedding@001",
|
||||||
|
input=[f"data:image/png;base64,{encoded_image}"],
|
||||||
|
client=vertex_client,
|
||||||
|
)
|
||||||
|
print("response: ", response)
|
||||||
|
image_embeddings.append(response.data[0].embedding)
|
||||||
|
|
||||||
|
# Mock text embedding based on input string
|
||||||
|
mock_text_embedding = (
|
||||||
|
[0.9] + [0.1] * 767 if input_string == "Duck" else [0.1] * 767 + [0.9]
|
||||||
|
)
|
||||||
|
text_mock_response = MagicMock()
|
||||||
|
text_mock_response.json.return_value = {
|
||||||
|
"predictions": [{"imageEmbedding": mock_text_embedding}]
|
||||||
|
}
|
||||||
|
text_mock_response.status_code = 200
|
||||||
|
with patch.object(
|
||||||
|
vertex_client, "post", return_value=text_mock_response
|
||||||
|
), patch.object(
|
||||||
|
litellm.main.vertex_multimodal_embedding,
|
||||||
|
"_ensure_access_token",
|
||||||
|
side_effect=mock_auth_token,
|
||||||
|
):
|
||||||
|
text_response = litellm.embedding(
|
||||||
|
model="vertex_ai/multimodalembedding@001",
|
||||||
|
input=[input_string],
|
||||||
|
client=vertex_client,
|
||||||
|
)
|
||||||
|
print("text_response: ", text_response)
|
||||||
|
text_embedding = text_response.data[0].embedding
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue