fix test_vertexai_multimodal_embedding: use MagicMock for requests

Author: Ishaan Jaff, 2024-08-22 09:56:24 -07:00
parent cc8e6f1d44
commit 4fe22ec493

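The commit replaces the test's live Vertex AI call with a patched async HTTP post whose mocked response returns canned predictions. The sketch below shows that pattern in isolation, as a rough illustration only: FakeAsyncClient, fetch_embedding, and the example URL are hypothetical stand-ins, not litellm code; the real test (see the full diff below) patches litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post and calls litellm.aembedding.

# Minimal sketch of the MagicMock/AsyncMock pattern this commit applies.
# FakeAsyncClient and fetch_embedding are illustrative stand-ins, not litellm code.
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch


class FakeAsyncClient:
    async def post(self, url, json=None):
        # Stand-in for the real async HTTP client; patched out below.
        raise RuntimeError("real network call - should never run under the patch")


async def fetch_embedding(client, payload):
    # Code under test: sends the payload and parses the JSON body.
    response = await client.post("https://example.invalid/embeddings", json=payload)
    return response.json()


async def main():
    # The mock response behaves like an httpx-style response: .status_code and .json().
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json = lambda: {"predictions": [{"textEmbedding": [0.4, 0.5, 0.6]}]}

    # Patch the async post method so no request leaves the process.
    with patch.object(
        FakeAsyncClient, "post", new=AsyncMock(return_value=mock_response)
    ) as mock_post:
        result = await fetch_embedding(FakeAsyncClient(), {"instances": []})

        mock_post.assert_called_once()   # the patched post was hit exactly once
        _, kwargs = mock_post.call_args  # inspect the payload the caller sent
        assert kwargs["json"] == {"instances": []}
        assert result["predictions"][0]["textEmbedding"] == [0.4, 0.5, 0.6]


asyncio.run(main())

The diff makes the same two kinds of assertions: the exact JSON payload sent to Vertex AI (expected_payload) and the fields parsed back out of the mocked response.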

@@ -15,7 +15,7 @@ import asyncio
 import json
 import os
 import tempfile
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
@@ -1842,12 +1842,40 @@ def test_vertexai_embedding():
         pytest.fail(f"Error occurred: {e}")


-@pytest.mark.asyncio()
+@pytest.mark.asyncio
 async def test_vertexai_multimodal_embedding():
     load_vertex_ai_credentials()
+    mock_response = AsyncMock()

-    try:
-        litellm.set_verbose = True
+    def return_val():
+        return {
+            "predictions": [
+                {
+                    "imageEmbedding": [0.1, 0.2, 0.3],  # Simplified example
+                    "textEmbedding": [0.4, 0.5, 0.6],  # Simplified example
+                }
+            ]
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "instances": [
+            {
+                "image": {
+                    "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
+                },
+                "text": "this is a unicorn",
+            }
+        ]
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.aembedding function
         response = await litellm.aembedding(
             model="vertex_ai/multimodalembedding@001",
             input=[
@@ -1859,17 +1887,24 @@ async def test_vertexai_multimodal_embedding():
                 },
             ],
         )
-        print(f"response:", response)
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_vertexai = kwargs["json"]
+
+        print("args to vertex ai call:", args_to_vertexai)
+
+        assert args_to_vertexai == expected_payload
         assert response.model == "multimodalembedding@001"
         assert len(response.data) == 1
-        response_data = response.data[0]
-        assert "imageEmbedding" in response_data
-        assert "textEmbedding" in response_data
+        _response_data = response.data[0]
+        assert "imageEmbedding" in _response_data
+        assert "textEmbedding" in _response_data

-    except litellm.RateLimitError as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+    # Optional: Print for debugging
+    print("Arguments passed to Vertex AI:", args_to_vertexai)
+    print("Response:", response)


 @pytest.mark.skip(