(fix) Vertex Improve Performance when using image_url (#6593)

* fix transformation vertex

* test test_process_gemini_image

* test_image_completion_request

* testing fix - bedrock has deprecated cohere.command-text-v14

* fix vertex pdf
This commit is contained in:
Ishaan Jaff 2024-11-05 11:25:09 +05:30 committed by GitHub
parent c047d51cc8
commit 96b0e324e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 180 additions and 3 deletions

View file

@ -15,6 +15,7 @@ sys.path.insert(
import pytest
import litellm
from litellm import get_optional_params
from litellm.llms.custom_httpx.http_handler import HTTPHandler
def test_completion_pydantic_obj_2():
@ -1171,3 +1172,148 @@ def test_logprobs():
print(resp)
assert resp.choices[0].logprobs is not None
def test_process_gemini_image():
"""Test the _process_gemini_image function for different image sources"""
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
_process_gemini_image,
)
from litellm.types.llms.vertex_ai import PartType, FileDataType, BlobType
# Test GCS URI
gcs_result = _process_gemini_image("gs://bucket/image.png")
assert gcs_result["file_data"] == FileDataType(
mime_type="image/png", file_uri="gs://bucket/image.png"
)
# Test HTTPS JPG URL
https_result = _process_gemini_image("https://example.com/image.jpg")
print("https_result JPG", https_result)
assert https_result["file_data"] == FileDataType(
mime_type="image/jpeg", file_uri="https://example.com/image.jpg"
)
# Test HTTPS PNG URL
https_result = _process_gemini_image("https://example.com/image.png")
print("https_result PNG", https_result)
assert https_result["file_data"] == FileDataType(
mime_type="image/png", file_uri="https://example.com/image.png"
)
# Test base64 image
base64_image = "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
base64_result = _process_gemini_image(base64_image)
print("base64_result", base64_result)
assert base64_result["inline_data"]["mime_type"] == "image/jpeg"
assert base64_result["inline_data"]["data"] == "/9j/4AAQSkZJRg..."
def test_get_image_mime_type_from_url():
"""Test the _get_image_mime_type_from_url function for different image URLs"""
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
_get_image_mime_type_from_url,
)
# Test JPEG images
assert (
_get_image_mime_type_from_url("https://example.com/image.jpg") == "image/jpeg"
)
assert (
_get_image_mime_type_from_url("https://example.com/image.jpeg") == "image/jpeg"
)
assert (
_get_image_mime_type_from_url("https://example.com/IMAGE.JPG") == "image/jpeg"
)
# Test PNG images
assert _get_image_mime_type_from_url("https://example.com/image.png") == "image/png"
assert _get_image_mime_type_from_url("https://example.com/IMAGE.PNG") == "image/png"
# Test WebP images
assert (
_get_image_mime_type_from_url("https://example.com/image.webp") == "image/webp"
)
assert (
_get_image_mime_type_from_url("https://example.com/IMAGE.WEBP") == "image/webp"
)
# Test unsupported formats
assert _get_image_mime_type_from_url("https://example.com/image.gif") is None
assert _get_image_mime_type_from_url("https://example.com/image.bmp") is None
assert _get_image_mime_type_from_url("https://example.com/image") is None
assert _get_image_mime_type_from_url("invalid_url") is None
@pytest.mark.parametrize(
"image_url", ["https://example.com/image.jpg", "https://example.com/image.png"]
)
def test_image_completion_request(image_url):
"""https:// .jpg, .png images are passed directly to the model"""
from unittest.mock import patch, Mock
import litellm
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
_get_image_mime_type_from_url,
)
# Mock response data
mock_response = Mock()
mock_response.json.return_value = {
"candidates": [{"content": {"parts": [{"text": "This is a sunflower"}]}}],
"usageMetadata": {
"promptTokenCount": 11,
"candidatesTokenCount": 50,
"totalTokenCount": 61,
},
"modelVersion": "gemini-1.5-pro",
}
mock_response.raise_for_status = MagicMock()
mock_response.status_code = 200
# Expected request body
expected_request_body = {
"contents": [
{
"role": "user",
"parts": [
{"text": "Whats in this image?"},
{
"file_data": {
"file_uri": image_url,
"mime_type": _get_image_mime_type_from_url(image_url),
}
},
],
}
],
"system_instruction": {"parts": [{"text": "Be a good bot"}]},
"generationConfig": {},
}
messages = [
{"role": "system", "content": "Be a good bot"},
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{"type": "image_url", "image_url": {"url": image_url}},
],
},
]
client = HTTPHandler()
with patch.object(client, "post", new=MagicMock()) as mock_post:
mock_post.return_value = mock_response
try:
litellm.completion(
model="gemini/gemini-1.5-pro",
messages=messages,
client=client,
)
except Exception as e:
print(e)
# Assert the request body matches expected
mock_post.assert_called_once()
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
assert mock_post.call_args.kwargs["json"] == expected_request_body