From 05b5a21014048b8f71a428d2722b827fcdf17ba7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 25 Nov 2024 21:15:19 +0530 Subject: [PATCH] fix(gemini/): fix image_url handling for gemini Fixes https://github.com/BerriAI/litellm/issues/6897 --- litellm/llms/prompt_templates/factory.py | 23 ++++++++++++ .../gemini/transformation.py | 7 +++- .../vertex_and_google_ai_studio_gemini.py | 37 +++++++++++++++++++ tests/llm_translation/base_llm_unit_tests.py | 29 +++++++++++++++ tests/llm_translation/test_gemini.py | 15 ++++++++ tests/llm_translation/test_prompt_factory.py | 13 +++++++ tests/llm_translation/test_vertex.py | 17 --------- 7 files changed, 123 insertions(+), 18 deletions(-) create mode 100644 tests/llm_translation/test_gemini.py diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 45b7a6c5b..cb79a81b7 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -33,6 +33,7 @@ from litellm.types.llms.openai import ( ChatCompletionAssistantToolCall, ChatCompletionFunctionMessage, ChatCompletionImageObject, + ChatCompletionImageUrlObject, ChatCompletionTextObject, ChatCompletionToolCallFunctionChunk, ChatCompletionToolMessage, @@ -681,6 +682,27 @@ def construct_tool_use_system_prompt( return tool_use_system_prompt +def convert_generic_image_chunk_to_openai_image_obj( + image_chunk: GenericImageParsingChunk, +) -> str: + """ + Convert a generic image chunk to an OpenAI image object. + + Input: + GenericImageParsingChunk( + type="base64", + media_type="image/jpeg", + data="...", + ) + + Return: + "data:image/jpeg;base64,{base64_image}" + """ + return "data:{};{},{}".format( + image_chunk["media_type"], image_chunk["type"], image_chunk["data"] + ) + + def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk: """ Input: @@ -706,6 +728,7 @@ def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsing data=base64_data, ) except Exception as e: + traceback.print_exc() if "Error: Unable to fetch image from URL" in str(e): raise e raise Exception( diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py index f828d93c8..4b5b7281b 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py @@ -294,7 +294,12 @@ def _transform_request_body( optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys} try: - content = _gemini_convert_messages_with_history(messages=messages) + if custom_llm_provider == "gemini": + content = litellm.GoogleAIStudioGeminiConfig._transform_messages( + messages=messages + ) + else: + content = litellm.VertexGeminiConfig._transform_messages(messages=messages) tools: Optional[Tools] = optional_params.pop("tools", None) tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py index f2fc599ed..4287ed1bc 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py @@ -35,7 +35,12 @@ from litellm.llms.custom_httpx.http_handler import ( HTTPHandler, get_async_httpx_client, ) +from litellm.llms.prompt_templates.factory import ( + convert_generic_image_chunk_to_openai_image_obj, + convert_to_anthropic_image_obj, +) from litellm.types.llms.openai import ( + AllMessageValues, ChatCompletionResponseMessage, ChatCompletionToolCallChunk, ChatCompletionToolCallFunctionChunk, @@ -78,6 +83,8 @@ from ..common_utils import ( ) from ..vertex_llm_base import VertexBase from .transformation import ( + _gemini_convert_messages_with_history, + _process_gemini_image, async_transform_request_body, set_headers, sync_transform_request_body, @@ -912,6 +919,10 @@ class VertexGeminiConfig: return model_response + @staticmethod + def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]: + return _gemini_convert_messages_with_history(messages=messages) + class GoogleAIStudioGeminiConfig( VertexGeminiConfig @@ -1015,6 +1026,32 @@ class GoogleAIStudioGeminiConfig( model, non_default_params, optional_params, drop_params ) + @staticmethod + def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]: + """ + Google AI Studio Gemini does not support image urls in messages. + """ + for message in messages: + _message_content = message.get("content") + if _message_content is not None and isinstance(_message_content, list): + _parts: List[PartType] = [] + for element in _message_content: + if element.get("type") == "image_url": + img_element = element + _image_url: Optional[str] = None + if isinstance(img_element.get("image_url"), dict): + _image_url = img_element["image_url"].get("url") # type: ignore + else: + _image_url = img_element.get("image_url") # type: ignore + if _image_url and "https://" in _image_url: + image_obj = convert_to_anthropic_image_obj(_image_url) + img_element["image_url"] = ( # type: ignore + convert_generic_image_chunk_to_openai_image_obj( + image_obj + ) + ) + return _gemini_convert_messages_with_history(messages=messages) + async def make_call( client: Optional[AsyncHTTPHandler], diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py index 88fce6dac..24a972e20 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -190,6 +190,35 @@ class BaseLLMChatTest(ABC): """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" pass + def test_image_url(self): + litellm.set_verbose = True + from litellm.utils import supports_vision + + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + + base_completion_call_args = self.get_base_completion_call_args() + if not supports_vision(base_completion_call_args["model"], None): + pytest.skip("Model does not support image input") + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg" + }, + }, + ], + } + ] + + response = litellm.completion(**base_completion_call_args, messages=messages) + assert response is not None + @pytest.fixture def pdf_messages(self): import base64 diff --git a/tests/llm_translation/test_gemini.py b/tests/llm_translation/test_gemini.py new file mode 100644 index 000000000..4e6c5118d --- /dev/null +++ b/tests/llm_translation/test_gemini.py @@ -0,0 +1,15 @@ +from base_llm_unit_tests import BaseLLMChatTest + + +class TestGoogleAIStudioGemini(BaseLLMChatTest): + def get_base_completion_call_args(self) -> dict: + return {"model": "gemini/gemini-1.5-flash"} + + def test_tool_call_no_arguments(self, tool_call_no_arguments): + """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" + from litellm.llms.prompt_templates.factory import ( + convert_to_gemini_tool_call_invoke, + ) + + result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments) + print(result) diff --git a/tests/llm_translation/test_prompt_factory.py b/tests/llm_translation/test_prompt_factory.py index 104997563..d8cf191f6 100644 --- a/tests/llm_translation/test_prompt_factory.py +++ b/tests/llm_translation/test_prompt_factory.py @@ -687,3 +687,16 @@ def test_just_system_message(): llm_provider="bedrock", ) assert "bedrock requires at least one non-system message" in str(e.value) + + +def test_convert_generic_image_chunk_to_openai_image_obj(): + from litellm.llms.prompt_templates.factory import ( + convert_generic_image_chunk_to_openai_image_obj, + convert_to_anthropic_image_obj, + ) + + url = "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg" + image_obj = convert_to_anthropic_image_obj(url) + url_str = convert_generic_image_chunk_to_openai_image_obj(image_obj) + image_obj = convert_to_anthropic_image_obj(url_str) + print(image_obj) diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py index 3e1087536..c2c1fdecf 100644 --- a/tests/llm_translation/test_vertex.py +++ b/tests/llm_translation/test_vertex.py @@ -1298,20 +1298,3 @@ def test_vertex_embedding_url(model, expected_url): assert url == expected_url assert endpoint == "predict" - - -from base_llm_unit_tests import BaseLLMChatTest - - -class TestVertexGemini(BaseLLMChatTest): - def get_base_completion_call_args(self) -> dict: - return {"model": "gemini/gemini-1.5-flash"} - - def test_tool_call_no_arguments(self, tool_call_no_arguments): - """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" - from litellm.llms.prompt_templates.factory import ( - convert_to_gemini_tool_call_invoke, - ) - - result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments) - print(result)