fix(gemini/): fix image_url handling for gemini
Fixes https://github.com/BerriAI/litellm/issues/6897
parent b55c829561
commit 05b5a21014

7 changed files with 123 additions and 18 deletions
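The linked issue reports that a chat message whose image_url points at a plain https:// URL fails for gemini/ (Google AI Studio) models. A minimal reproduction of the call this commit fixes, adapted from the test_image_url test added below (a sketch: it assumes a valid GEMINI_API_KEY in the environment):

import litellm

# Before this fix, the https URL was forwarded to Google AI Studio as-is,
# which the API rejects; after it, litellm downloads the image and inlines
# it as a base64 data URI.
response = litellm.completion(
    model="gemini/gemini-1.5-flash",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
                    },
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)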
@@ -33,6 +33,7 @@ from litellm.types.llms.openai import (
     ChatCompletionAssistantToolCall,
     ChatCompletionFunctionMessage,
     ChatCompletionImageObject,
+    ChatCompletionImageUrlObject,
     ChatCompletionTextObject,
     ChatCompletionToolCallFunctionChunk,
     ChatCompletionToolMessage,
@@ -681,6 +682,27 @@ def construct_tool_use_system_prompt(
     return tool_use_system_prompt
 
 
+def convert_generic_image_chunk_to_openai_image_obj(
+    image_chunk: GenericImageParsingChunk,
+) -> str:
+    """
+    Convert a generic image chunk to an OpenAI image object.
+
+    Input:
+    GenericImageParsingChunk(
+        type="base64",
+        media_type="image/jpeg",
+        data="...",
+    )
+
+    Return:
+    "data:image/jpeg;base64,{base64_image}"
+    """
+    return "data:{};{},{}".format(
+        image_chunk["media_type"], image_chunk["type"], image_chunk["data"]
+    )
+
+
 def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
     """
     Input:
@@ -706,6 +728,7 @@ def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsing
             data=base64_data,
         )
     except Exception as e:
+        traceback.print_exc()
         if "Error: Unable to fetch image from URL" in str(e):
             raise e
         raise Exception(
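Taken together, these factory helpers round-trip between an OpenAI-style data URI and the internal GenericImageParsingChunk dict. A usage sketch of that flow (the same sequence the new test_convert_generic_image_chunk_to_openai_image_obj test near the end of this diff runs):

from litellm.llms.prompt_templates.factory import (
    convert_generic_image_chunk_to_openai_image_obj,
    convert_to_anthropic_image_obj,
)

url = "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"

# Fetch the remote image and parse it into
# {"type": "base64", "media_type": "image/jpeg", "data": "..."}.
image_obj = convert_to_anthropic_image_obj(url)

# Flatten it back into an OpenAI-style "data:image/jpeg;base64,..." string.
url_str = convert_generic_image_chunk_to_openai_image_obj(image_obj)

# Parsing the data URI again should yield an equivalent chunk.
print(convert_to_anthropic_image_obj(url_str))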
@@ -294,7 +294,12 @@ def _transform_request_body(
     optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}
 
     try:
-        content = _gemini_convert_messages_with_history(messages=messages)
+        if custom_llm_provider == "gemini":
+            content = litellm.GoogleAIStudioGeminiConfig._transform_messages(
+                messages=messages
+            )
+        else:
+            content = litellm.VertexGeminiConfig._transform_messages(messages=messages)
         tools: Optional[Tools] = optional_params.pop("tools", None)
         tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
        safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
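The effect of this hunk is that request construction now dispatches to a per-provider _transform_messages hook instead of calling _gemini_convert_messages_with_history directly. A rough sketch of the two paths, assuming both config classes are reachable on the top-level litellm module as the diff's own call sites suggest:

import litellm

messages = [{"role": "user", "content": "What's in this image?"}]

# Google AI Studio path: https image URLs are downloaded and inlined first.
content = litellm.GoogleAIStudioGeminiConfig._transform_messages(messages=messages)

# Vertex AI path: messages go straight to the history converter, since
# Vertex (unlike AI Studio, per the new docstring) handles URL-based images.
content = litellm.VertexGeminiConfig._transform_messages(messages=messages)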
@@ -35,7 +35,12 @@ from litellm.llms.custom_httpx.http_handler import (
     HTTPHandler,
     get_async_httpx_client,
 )
+from litellm.llms.prompt_templates.factory import (
+    convert_generic_image_chunk_to_openai_image_obj,
+    convert_to_anthropic_image_obj,
+)
 from litellm.types.llms.openai import (
+    AllMessageValues,
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
@@ -78,6 +83,8 @@ from ..common_utils import (
 )
 from ..vertex_llm_base import VertexBase
 from .transformation import (
+    _gemini_convert_messages_with_history,
+    _process_gemini_image,
     async_transform_request_body,
     set_headers,
     sync_transform_request_body,
@@ -912,6 +919,10 @@ class VertexGeminiConfig:
 
         return model_response
 
+    @staticmethod
+    def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
+        return _gemini_convert_messages_with_history(messages=messages)
+
 
 class GoogleAIStudioGeminiConfig(
     VertexGeminiConfig
@@ -1015,6 +1026,32 @@ class GoogleAIStudioGeminiConfig(
             model, non_default_params, optional_params, drop_params
         )
 
+    @staticmethod
+    def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
+        """
+        Google AI Studio Gemini does not support image urls in messages.
+        """
+        for message in messages:
+            _message_content = message.get("content")
+            if _message_content is not None and isinstance(_message_content, list):
+                _parts: List[PartType] = []
+                for element in _message_content:
+                    if element.get("type") == "image_url":
+                        img_element = element
+                        _image_url: Optional[str] = None
+                        if isinstance(img_element.get("image_url"), dict):
+                            _image_url = img_element["image_url"].get("url")  # type: ignore
+                        else:
+                            _image_url = img_element.get("image_url")  # type: ignore
+                        if _image_url and "https://" in _image_url:
+                            image_obj = convert_to_anthropic_image_obj(_image_url)
+                            img_element["image_url"] = (  # type: ignore
+                                convert_generic_image_chunk_to_openai_image_obj(
+                                    image_obj
+                                )
+                            )
+        return _gemini_convert_messages_with_history(messages=messages)
+
 
 async def make_call(
     client: Optional[AsyncHTTPHandler],
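In short, the new override walks every list-typed message content; each image_url element that points at an https:// URL is fetched, base64-encoded, and written back in place before the usual history conversion runs. A small illustration of that in-place rewrite (the example.com URL and the truncated base64 payload are placeholders):

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
        ],
    }
]

# After GoogleAIStudioGeminiConfig._transform_messages(messages=messages),
# the element has been mutated in place, and its value is now a plain data
# URI string rather than a {"url": ...} dict:
# {"type": "image_url", "image_url": "data:image/jpeg;base64,/9j/4AAQ..."}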
@@ -190,6 +190,35 @@ class BaseLLMChatTest(ABC):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass
 
+    def test_image_url(self):
+        litellm.set_verbose = True
+        from litellm.utils import supports_vision
+
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+
+        base_completion_call_args = self.get_base_completion_call_args()
+        if not supports_vision(base_completion_call_args["model"], None):
+            pytest.skip("Model does not support image input")
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What's in this image?"},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
+                        },
+                    },
+                ],
+            }
+        ]
+
+        response = litellm.completion(**base_completion_call_args, messages=messages)
+        assert response is not None
+
     @pytest.fixture
     def pdf_messages(self):
         import base64
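Note the supports_vision guard at the top of test_image_url: subclasses whose model does not report vision support skip the test, so non-vision providers inheriting BaseLLMChatTest are unaffected.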
tests/llm_translation/test_gemini.py (new file, +15)

@@ -0,0 +1,15 @@
+from base_llm_unit_tests import BaseLLMChatTest
+
+
+class TestGoogleAIStudioGemini(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        return {"model": "gemini/gemini-1.5-flash"}
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        from litellm.llms.prompt_templates.factory import (
+            convert_to_gemini_tool_call_invoke,
+        )
+
+        result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
+        print(result)
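Since TestGoogleAIStudioGemini extends BaseLLMChatTest, it also inherits the test_image_url case added above, which is what exercises this commit's https image handling end to end against gemini/gemini-1.5-flash.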
@@ -687,3 +687,16 @@ def test_just_system_message():
            llm_provider="bedrock",
        )
    assert "bedrock requires at least one non-system message" in str(e.value)
+
+
+def test_convert_generic_image_chunk_to_openai_image_obj():
+    from litellm.llms.prompt_templates.factory import (
+        convert_generic_image_chunk_to_openai_image_obj,
+        convert_to_anthropic_image_obj,
+    )
+
+    url = "https://i.pinimg.com/736x/b4/b1/be/b4b1becad04d03a9071db2817fc9fe77.jpg"
+    image_obj = convert_to_anthropic_image_obj(url)
+    url_str = convert_generic_image_chunk_to_openai_image_obj(image_obj)
+    image_obj = convert_to_anthropic_image_obj(url_str)
+    print(image_obj)
@@ -1298,20 +1298,3 @@ def test_vertex_embedding_url(model, expected_url):
 
     assert url == expected_url
     assert endpoint == "predict"
-
-
-from base_llm_unit_tests import BaseLLMChatTest
-
-
-class TestVertexGemini(BaseLLMChatTest):
-    def get_base_completion_call_args(self) -> dict:
-        return {"model": "gemini/gemini-1.5-flash"}
-
-    def test_tool_call_no_arguments(self, tool_call_no_arguments):
-        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
-        from litellm.llms.prompt_templates.factory import (
-            convert_to_gemini_tool_call_invoke,
-        )
-
-        result = convert_to_gemini_tool_call_invoke(tool_call_no_arguments)
-        print(result)
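This removal appears to be the counterpart of the new tests/llm_translation/test_gemini.py above: the same test body previously lived here under the name TestVertexGemini even though it targets a gemini/ (Google AI Studio) model, so it is moved rather than deleted.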