From 96b0e324e306c78611cdb4add6f4134fc7a2686c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 5 Nov 2024 11:25:09 +0530 Subject: [PATCH] (fix) Vertex Improve Performance when using `image_url` (#6593) * fix transformation vertex * test test_process_gemini_image * test_image_completion_request * testing fix - bedrock has deprecated cohere.command-text-v14 * fix vertex pdf --- .../gemini/transformation.py | 37 ++++- tests/llm_translation/test_vertex.py | 146 ++++++++++++++++++ 2 files changed, 180 insertions(+), 3 deletions(-) diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py index 66ab07674..f828d93c8 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py @@ -51,6 +51,9 @@ from ..common_utils import ( def _process_gemini_image(image_url: str) -> PartType: + """ + Given an image URL, return the appropriate PartType for Gemini + """ try: # GCS URIs if "gs://" in image_url: @@ -68,9 +71,14 @@ def _process_gemini_image(image_url: str) -> PartType: file_data = FileDataType(mime_type=mime_type, file_uri=image_url) return PartType(file_data=file_data) - - # Direct links - elif "https:/" in image_url or "base64" in image_url: + elif ( + "https://" in image_url + and (image_type := _get_image_mime_type_from_url(image_url)) is not None + ): + file_data = FileDataType(file_uri=image_url, mime_type=image_type) + return PartType(file_data=file_data) + elif "https://" in image_url or "base64" in image_url: + # https links for unsupported mime types and base64 images image = convert_to_anthropic_image_obj(image_url) _blob = BlobType(data=image["data"], mime_type=image["media_type"]) return PartType(inline_data=_blob) @@ -79,6 +87,29 @@ def _process_gemini_image(image_url: str) -> PartType: raise e +def _get_image_mime_type_from_url(url: str) -> Optional[str]: + """ + Get mime type for common image URLs + See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements + + Supported by Gemini: + - PNG (`image/png`) + - JPEG (`image/jpeg`) + - WebP (`image/webp`) + Example: + url = https://example.com/image.jpg + Returns: image/jpeg + """ + url = url.lower() + if url.endswith((".jpg", ".jpeg")): + return "image/jpeg" + elif url.endswith(".png"): + return "image/png" + elif url.endswith(".webp"): + return "image/webp" + return None + + def _gemini_convert_messages_with_history( # noqa: PLR0915 messages: List[AllMessageValues], ) -> List[ContentType]: diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py index 467be4ddf..a06179a49 100644 --- a/tests/llm_translation/test_vertex.py +++ b/tests/llm_translation/test_vertex.py @@ -15,6 +15,7 @@ sys.path.insert( import pytest import litellm from litellm import get_optional_params +from litellm.llms.custom_httpx.http_handler import HTTPHandler def test_completion_pydantic_obj_2(): @@ -1171,3 +1172,148 @@ def test_logprobs(): print(resp) assert resp.choices[0].logprobs is not None + + +def test_process_gemini_image(): + """Test the _process_gemini_image function for different image sources""" + from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import ( + _process_gemini_image, + ) + from litellm.types.llms.vertex_ai import PartType, FileDataType, BlobType + + # Test GCS URI + gcs_result = _process_gemini_image("gs://bucket/image.png") + assert gcs_result["file_data"] == FileDataType( + mime_type="image/png", file_uri="gs://bucket/image.png" + ) + + # Test HTTPS JPG URL + https_result = _process_gemini_image("https://example.com/image.jpg") + print("https_result JPG", https_result) + assert https_result["file_data"] == FileDataType( + mime_type="image/jpeg", file_uri="https://example.com/image.jpg" + ) + + # Test HTTPS PNG URL + https_result = _process_gemini_image("https://example.com/image.png") + print("https_result PNG", https_result) + assert https_result["file_data"] == FileDataType( + mime_type="image/png", file_uri="https://example.com/image.png" + ) + + # Test base64 image + base64_image = "..." + base64_result = _process_gemini_image(base64_image) + print("base64_result", base64_result) + assert base64_result["inline_data"]["mime_type"] == "image/jpeg" + assert base64_result["inline_data"]["data"] == "/9j/4AAQSkZJRg..." + + +def test_get_image_mime_type_from_url(): + """Test the _get_image_mime_type_from_url function for different image URLs""" + from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import ( + _get_image_mime_type_from_url, + ) + + # Test JPEG images + assert ( + _get_image_mime_type_from_url("https://example.com/image.jpg") == "image/jpeg" + ) + assert ( + _get_image_mime_type_from_url("https://example.com/image.jpeg") == "image/jpeg" + ) + assert ( + _get_image_mime_type_from_url("https://example.com/IMAGE.JPG") == "image/jpeg" + ) + + # Test PNG images + assert _get_image_mime_type_from_url("https://example.com/image.png") == "image/png" + assert _get_image_mime_type_from_url("https://example.com/IMAGE.PNG") == "image/png" + + # Test WebP images + assert ( + _get_image_mime_type_from_url("https://example.com/image.webp") == "image/webp" + ) + assert ( + _get_image_mime_type_from_url("https://example.com/IMAGE.WEBP") == "image/webp" + ) + + # Test unsupported formats + assert _get_image_mime_type_from_url("https://example.com/image.gif") is None + assert _get_image_mime_type_from_url("https://example.com/image.bmp") is None + assert _get_image_mime_type_from_url("https://example.com/image") is None + assert _get_image_mime_type_from_url("invalid_url") is None + + +@pytest.mark.parametrize( + "image_url", ["https://example.com/image.jpg", "https://example.com/image.png"] +) +def test_image_completion_request(image_url): + """https:// .jpg, .png images are passed directly to the model""" + from unittest.mock import patch, Mock + import litellm + from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import ( + _get_image_mime_type_from_url, + ) + + # Mock response data + mock_response = Mock() + mock_response.json.return_value = { + "candidates": [{"content": {"parts": [{"text": "This is a sunflower"}]}}], + "usageMetadata": { + "promptTokenCount": 11, + "candidatesTokenCount": 50, + "totalTokenCount": 61, + }, + "modelVersion": "gemini-1.5-pro", + } + mock_response.raise_for_status = MagicMock() + mock_response.status_code = 200 + + # Expected request body + expected_request_body = { + "contents": [ + { + "role": "user", + "parts": [ + {"text": "Whats in this image?"}, + { + "file_data": { + "file_uri": image_url, + "mime_type": _get_image_mime_type_from_url(image_url), + } + }, + ], + } + ], + "system_instruction": {"parts": [{"text": "Be a good bot"}]}, + "generationConfig": {}, + } + + messages = [ + {"role": "system", "content": "Be a good bot"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Whats in this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + }, + ] + + client = HTTPHandler() + with patch.object(client, "post", new=MagicMock()) as mock_post: + mock_post.return_value = mock_response + try: + litellm.completion( + model="gemini/gemini-1.5-pro", + messages=messages, + client=client, + ) + except Exception as e: + print(e) + + # Assert the request body matches expected + mock_post.assert_called_once() + print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"]) + assert mock_post.call_args.kwargs["json"] == expected_request_body