fix(gemini/): fix image_url handling for gemini

Fixes https://github.com/BerriAI/litellm/issues/6897
This commit is contained in:
Krrish Dholakia 2024-11-25 21:15:19 +05:30
parent b55c829561
commit 05b5a21014
7 changed files with 123 additions and 18 deletions

View file

@ -33,6 +33,7 @@ from litellm.types.llms.openai import (
ChatCompletionAssistantToolCall,
ChatCompletionFunctionMessage,
ChatCompletionImageObject,
ChatCompletionImageUrlObject,
ChatCompletionTextObject,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolMessage,
@ -681,6 +682,27 @@ def construct_tool_use_system_prompt(
return tool_use_system_prompt
def convert_generic_image_chunk_to_openai_image_obj(
image_chunk: GenericImageParsingChunk,
) -> str:
"""
Convert a generic image chunk to an OpenAI image object.
Input:
GenericImageParsingChunk(
type="base64",
media_type="image/jpeg",
data="...",
)
Return:
"data:image/jpeg;base64,{base64_image}"
"""
return "data:{};{},{}".format(
image_chunk["media_type"], image_chunk["type"], image_chunk["data"]
)
def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
"""
Input:
@ -706,6 +728,7 @@ def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsing
data=base64_data,
)
except Exception as e:
traceback.print_exc()
if "Error: Unable to fetch image from URL" in str(e):
raise e
raise Exception(

View file

@ -294,7 +294,12 @@ def _transform_request_body(
optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}
try:
content = _gemini_convert_messages_with_history(messages=messages)
if custom_llm_provider == "gemini":
content = litellm.GoogleAIStudioGeminiConfig._transform_messages(
messages=messages
)
else:
content = litellm.VertexGeminiConfig._transform_messages(messages=messages)
tools: Optional[Tools] = optional_params.pop("tools", None)
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(

View file

@ -35,7 +35,12 @@ from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.prompt_templates.factory import (
convert_generic_image_chunk_to_openai_image_obj,
convert_to_anthropic_image_obj,
)
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
@ -78,6 +83,8 @@ from ..common_utils import (
)
from ..vertex_llm_base import VertexBase
from .transformation import (
_gemini_convert_messages_with_history,
_process_gemini_image,
async_transform_request_body,
set_headers,
sync_transform_request_body,
@ -912,6 +919,10 @@ class VertexGeminiConfig:
return model_response
@staticmethod
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
return _gemini_convert_messages_with_history(messages=messages)
class GoogleAIStudioGeminiConfig(
VertexGeminiConfig
@ -1015,6 +1026,32 @@ class GoogleAIStudioGeminiConfig(
model, non_default_params, optional_params, drop_params
)
@staticmethod
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
"""
Google AI Studio Gemini does not support image urls in messages.
"""
for message in messages:
_message_content = message.get("content")
if _message_content is not None and isinstance(_message_content, list):
_parts: List[PartType] = []
for element in _message_content:
if element.get("type") == "image_url":
img_element = element
_image_url: Optional[str] = None
if isinstance(img_element.get("image_url"), dict):
_image_url = img_element["image_url"].get("url") # type: ignore
else:
_image_url = img_element.get("image_url") # type: ignore
if _image_url and "https://" in _image_url:
image_obj = convert_to_anthropic_image_obj(_image_url)
img_element["image_url"] = ( # type: ignore
convert_generic_image_chunk_to_openai_image_obj(
image_obj
)
)
return _gemini_convert_messages_with_history(messages=messages)
async def make_call(
client: Optional[AsyncHTTPHandler],

View file

@ -190,6 +190,35 @@ class BaseLLMChatTest(ABC):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
pass
def test_image_url(self):
litellm.set_verbose = True
from litellm.utils import supports_vision
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
base_completion_call_args = self.get_base_completion_call_args()
if not supports_vision(base_completion_call_args["model"], None):
pytest.skip("Model does not support image input")
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
},
},
],
}
]
response = litellm.completion(**base_completion_call_args, messages=messages)
assert response is not None
@pytest.fixture
def pdf_messages(self):
import base64

View file

@ -0,0 +1,15 @@
from base_llm_unit_tests import BaseLLMChatTest
class TestGoogleAIStudioGemini(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "gemini/gemini-1.5-flash"}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_invoke,
)
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
print(result)

View file

@ -687,3 +687,16 @@ def test_just_system_message():
llm_provider="bedrock",
)
assert "bedrock requires at least one non-system message" in str(e.value)
def test_convert_generic_image_chunk_to_openai_image_obj():
from litellm.llms.prompt_templates.factory import (
convert_generic_image_chunk_to_openai_image_obj,
convert_to_anthropic_image_obj,
)
url = "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
image_obj = convert_to_anthropic_image_obj(url)
url_str = convert_generic_image_chunk_to_openai_image_obj(image_obj)
image_obj = convert_to_anthropic_image_obj(url_str)
print(image_obj)

View file

@ -1298,20 +1298,3 @@ def test_vertex_embedding_url(model, expected_url):
assert url == expected_url
assert endpoint == "predict"
from base_llm_unit_tests import BaseLLMChatTest
class TestVertexGemini(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "gemini/gemini-1.5-flash"}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_invoke,
)
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
print(result)