forked from phoenix/litellm-mirror
fix(gemini/): fix image_url handling for gemini
Fixes https://github.com/BerriAI/litellm/issues/6897
This commit is contained in:
parent
b55c829561
commit
05b5a21014
7 changed files with 123 additions and 18 deletions
|
@ -33,6 +33,7 @@ from litellm.types.llms.openai import (
|
|||
ChatCompletionAssistantToolCall,
|
||||
ChatCompletionFunctionMessage,
|
||||
ChatCompletionImageObject,
|
||||
ChatCompletionImageUrlObject,
|
||||
ChatCompletionTextObject,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionToolMessage,
|
||||
|
@ -681,6 +682,27 @@ def construct_tool_use_system_prompt(
|
|||
return tool_use_system_prompt
|
||||
|
||||
|
||||
def convert_generic_image_chunk_to_openai_image_obj(
|
||||
image_chunk: GenericImageParsingChunk,
|
||||
) -> str:
|
||||
"""
|
||||
Convert a generic image chunk to an OpenAI image object.
|
||||
|
||||
Input:
|
||||
GenericImageParsingChunk(
|
||||
type="base64",
|
||||
media_type="image/jpeg",
|
||||
data="...",
|
||||
)
|
||||
|
||||
Return:
|
||||
"data:image/jpeg;base64,{base64_image}"
|
||||
"""
|
||||
return "data:{};{},{}".format(
|
||||
image_chunk["media_type"], image_chunk["type"], image_chunk["data"]
|
||||
)
|
||||
|
||||
|
||||
def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
|
||||
"""
|
||||
Input:
|
||||
|
@ -706,6 +728,7 @@ def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsing
|
|||
data=base64_data,
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
if "Error: Unable to fetch image from URL" in str(e):
|
||||
raise e
|
||||
raise Exception(
|
||||
|
|
|
@ -294,7 +294,12 @@ def _transform_request_body(
|
|||
optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}
|
||||
|
||||
try:
|
||||
content = _gemini_convert_messages_with_history(messages=messages)
|
||||
if custom_llm_provider == "gemini":
|
||||
content = litellm.GoogleAIStudioGeminiConfig._transform_messages(
|
||||
messages=messages
|
||||
)
|
||||
else:
|
||||
content = litellm.VertexGeminiConfig._transform_messages(messages=messages)
|
||||
tools: Optional[Tools] = optional_params.pop("tools", None)
|
||||
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
|
||||
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
|
||||
|
|
|
@ -35,7 +35,12 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.prompt_templates.factory import (
|
||||
convert_generic_image_chunk_to_openai_image_obj,
|
||||
convert_to_anthropic_image_obj,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionResponseMessage,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
|
@ -78,6 +83,8 @@ from ..common_utils import (
|
|||
)
|
||||
from ..vertex_llm_base import VertexBase
|
||||
from .transformation import (
|
||||
_gemini_convert_messages_with_history,
|
||||
_process_gemini_image,
|
||||
async_transform_request_body,
|
||||
set_headers,
|
||||
sync_transform_request_body,
|
||||
|
@ -912,6 +919,10 @@ class VertexGeminiConfig:
|
|||
|
||||
return model_response
|
||||
|
||||
@staticmethod
|
||||
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
|
||||
return _gemini_convert_messages_with_history(messages=messages)
|
||||
|
||||
|
||||
class GoogleAIStudioGeminiConfig(
|
||||
VertexGeminiConfig
|
||||
|
@ -1015,6 +1026,32 @@ class GoogleAIStudioGeminiConfig(
|
|||
model, non_default_params, optional_params, drop_params
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
|
||||
"""
|
||||
Google AI Studio Gemini does not support image urls in messages.
|
||||
"""
|
||||
for message in messages:
|
||||
_message_content = message.get("content")
|
||||
if _message_content is not None and isinstance(_message_content, list):
|
||||
_parts: List[PartType] = []
|
||||
for element in _message_content:
|
||||
if element.get("type") == "image_url":
|
||||
img_element = element
|
||||
_image_url: Optional[str] = None
|
||||
if isinstance(img_element.get("image_url"), dict):
|
||||
_image_url = img_element["image_url"].get("url") # type: ignore
|
||||
else:
|
||||
_image_url = img_element.get("image_url") # type: ignore
|
||||
if _image_url and "https://" in _image_url:
|
||||
image_obj = convert_to_anthropic_image_obj(_image_url)
|
||||
img_element["image_url"] = ( # type: ignore
|
||||
convert_generic_image_chunk_to_openai_image_obj(
|
||||
image_obj
|
||||
)
|
||||
)
|
||||
return _gemini_convert_messages_with_history(messages=messages)
|
||||
|
||||
|
||||
async def make_call(
|
||||
client: Optional[AsyncHTTPHandler],
|
||||
|
|
|
@ -190,6 +190,35 @@ class BaseLLMChatTest(ABC):
|
|||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||
pass
|
||||
|
||||
def test_image_url(self):
|
||||
litellm.set_verbose = True
|
||||
from litellm.utils import supports_vision
|
||||
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
|
||||
base_completion_call_args = self.get_base_completion_call_args()
|
||||
if not supports_vision(base_completion_call_args["model"], None):
|
||||
pytest.skip("Model does not support image input")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
response = litellm.completion(**base_completion_call_args, messages=messages)
|
||||
assert response is not None
|
||||
|
||||
@pytest.fixture
|
||||
def pdf_messages(self):
|
||||
import base64
|
||||
|
|
15
tests/llm_translation/test_gemini.py
Normal file
15
tests/llm_translation/test_gemini.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
from base_llm_unit_tests import BaseLLMChatTest
|
||||
|
||||
|
||||
class TestGoogleAIStudioGemini(BaseLLMChatTest):
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
return {"model": "gemini/gemini-1.5-flash"}
|
||||
|
||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||
from litellm.llms.prompt_templates.factory import (
|
||||
convert_to_gemini_tool_call_invoke,
|
||||
)
|
||||
|
||||
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
|
||||
print(result)
|
|
@ -687,3 +687,16 @@ def test_just_system_message():
|
|||
llm_provider="bedrock",
|
||||
)
|
||||
assert "bedrock requires at least one non-system message" in str(e.value)
|
||||
|
||||
|
||||
def test_convert_generic_image_chunk_to_openai_image_obj():
|
||||
from litellm.llms.prompt_templates.factory import (
|
||||
convert_generic_image_chunk_to_openai_image_obj,
|
||||
convert_to_anthropic_image_obj,
|
||||
)
|
||||
|
||||
url = "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
|
||||
image_obj = convert_to_anthropic_image_obj(url)
|
||||
url_str = convert_generic_image_chunk_to_openai_image_obj(image_obj)
|
||||
image_obj = convert_to_anthropic_image_obj(url_str)
|
||||
print(image_obj)
|
||||
|
|
|
@ -1298,20 +1298,3 @@ def test_vertex_embedding_url(model, expected_url):
|
|||
|
||||
assert url == expected_url
|
||||
assert endpoint == "predict"
|
||||
|
||||
|
||||
from base_llm_unit_tests import BaseLLMChatTest
|
||||
|
||||
|
||||
class TestVertexGemini(BaseLLMChatTest):
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
return {"model": "gemini/gemini-1.5-flash"}
|
||||
|
||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||
from litellm.llms.prompt_templates.factory import (
|
||||
convert_to_gemini_tool_call_invoke,
|
||||
)
|
||||
|
||||
result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
|
||||
print(result)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue