Merge pull request #4266 from BerriAI/litellm_gemini_image_url

Support 'image url' to vertex ai / google ai studio gemini models
Krish Dholakia 2024-06-18 20:39:25 -07:00 committed by GitHub
commit 0c2c02ba8d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 140 additions and 143 deletions
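Note: in practice, the change means an OpenAI-style image_url content block now works against Gemini both through Vertex AI and through the Google AI Studio route. A minimal sketch of the calling pattern this enables, borrowing the model name and image URL from the tests below (a configured GEMINI_API_KEY or Vertex credentials are assumed):

import litellm

response = litellm.completion(
    model="gemini/gemini-1.5-flash",  # or "vertex_ai/gemini-pro-vision"
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Whats in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
                    },
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)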

View file

@@ -242,7 +242,7 @@ class Logging:
                    extra={"api_base": {api_base}, **masked_headers},
                )
            else:
-                verbose_logger.debug(f"\033[92m{curl_command}\033[0m\n")
+                print_verbose(f"\033[92m{curl_command}\033[0m\n", log_level="DEBUG")
            # log raw request to provider (like LangFuse) -- if opted in.
            if log_raw_request_response is True:
                try:

View file

@@ -1,24 +1,30 @@
+import json
+import re
+import traceback
+import uuid
+import xml.etree.ElementTree as ET
 from enum import Enum
-import requests, traceback
-import json, re, xml.etree.ElementTree as ET
-from jinja2 import Template, exceptions, meta, BaseLoader
-from jinja2.sandbox import ImmutableSandboxedEnvironment
 from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
+
+import requests
+from jinja2 import BaseLoader, Template, exceptions, meta
+from jinja2.sandbox import ImmutableSandboxedEnvironment
+
 import litellm
 import litellm.types
-from litellm.types.completion import (
-    ChatCompletionUserMessageParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionMessageParam,
-    ChatCompletionFunctionMessageParam,
-    ChatCompletionMessageToolCallParam,
-    ChatCompletionToolMessageParam,
-)
 import litellm.types.llms
-from litellm.types.llms.anthropic import *
-import uuid
-from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
 import litellm.types.llms.vertex_ai
+from litellm.types.completion import (
+    ChatCompletionFunctionMessageParam,
+    ChatCompletionMessageParam,
+    ChatCompletionMessageToolCallParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionUserMessageParam,
+)
+from litellm.types.llms.anthropic import *
+from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
+from litellm.types.utils import GenericImageParsingChunk


 def default_pt(messages):
@@ -622,9 +628,10 @@ def construct_tool_use_system_prompt(

 def convert_url_to_base64(url):
-    import requests
     import base64
+    import requests
+
     for _ in range(3):
         try:
             response = requests.get(url)
@@ -654,7 +661,7 @@ def convert_url_to_base64(url):
     raise Exception(f"Error: Unable to fetch image from URL. url={url}")


-def convert_to_anthropic_image_obj(openai_image_url: str):
+def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
     """
     Input:
         "image_url": "data:image/jpeg;base64,{base64_image}",
@@ -675,11 +682,11 @@ def convert_to_anthropic_image_obj(openai_image_url: str):
         # Infer image format from the URL
         image_format = openai_image_url.split("data:image/")[1].split(";base64,")[0]

-        return {
-            "type": "base64",
-            "media_type": f"image/{image_format}",
-            "data": base64_data,
-        }
+        return GenericImageParsingChunk(
+            type="base64",
+            media_type=f"image/{image_format}",
+            data=base64_data,
+        )
     except Exception as e:
         if "Error: Unable to fetch image from URL" in str(e):
             raise e
@@ -1606,19 +1613,23 @@ def azure_text_pt(messages: list):

 ###### AMAZON BEDROCK #######

+from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock
+from litellm.types.llms.bedrock import ImageBlock as BedrockImageBlock
+from litellm.types.llms.bedrock import ImageSourceBlock as BedrockImageSourceBlock
+from litellm.types.llms.bedrock import ToolBlock as BedrockToolBlock
 from litellm.types.llms.bedrock import (
-    ToolResultContentBlock as BedrockToolResultContentBlock,
-    ToolResultBlock as BedrockToolResultBlock,
-    ToolConfigBlock as BedrockToolConfigBlock,
-    ToolUseBlock as BedrockToolUseBlock,
-    ImageSourceBlock as BedrockImageSourceBlock,
-    ImageBlock as BedrockImageBlock,
-    ContentBlock as BedrockContentBlock,
-    ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
-    ToolSpecBlock as BedrockToolSpecBlock,
-    ToolBlock as BedrockToolBlock,
     ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
 )
+from litellm.types.llms.bedrock import ToolConfigBlock as BedrockToolConfigBlock
+from litellm.types.llms.bedrock import (
+    ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
+)
+from litellm.types.llms.bedrock import ToolResultBlock as BedrockToolResultBlock
+from litellm.types.llms.bedrock import (
+    ToolResultContentBlock as BedrockToolResultContentBlock,
+)
+from litellm.types.llms.bedrock import ToolSpecBlock as BedrockToolSpecBlock
+from litellm.types.llms.bedrock import ToolUseBlock as BedrockToolUseBlock


 def get_image_details(image_url) -> Tuple[str, str]:
@@ -1655,7 +1666,8 @@ def get_image_details(image_url) -> Tuple[str, str]:
 def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
     if "base64" in image_url:
         # Case 1: Images with base64 encoding
-        import base64, re
+        import base64
+        import re

         # base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
         image_metadata, img_without_base_64 = image_url.split(",")
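Note: convert_to_anthropic_image_obj keeps its behavior (parse a data:image/...;base64,... URL, or fetch an https:// URL and base64-encode it first) but now returns the typed GenericImageParsingChunk instead of a bare dict. Since a TypedDict is a plain dict at runtime, existing dict-style callers are unaffected. A minimal sketch of the return shape, going by the function's docstring contract and assuming a data-URL input:

from litellm.llms.prompt_templates.factory import convert_to_anthropic_image_obj

chunk = convert_to_anthropic_image_obj("data:image/png;base64,iVBORw0KGgo=")
assert chunk["type"] == "base64"           # always "base64" for parsed images
assert chunk["media_type"] == "image/png"  # inferred from the data URL
assert chunk["data"] == "iVBORw0KGgo="     # the base64 payload, not decoded bytes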

View file

@@ -1,18 +1,22 @@
-import os, types
+import inspect
 import json
-from enum import Enum
-import requests  # type: ignore
+import os
 import time
-from typing import Callable, Optional, Union, List, Literal, Any
+import types
+import uuid
+from enum import Enum
+from typing import Any, Callable, List, Literal, Optional, Union

+import httpx  # type: ignore
+import requests  # type: ignore
 from pydantic import BaseModel
-from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
+
+import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-import litellm, uuid
-import httpx, inspect  # type: ignore
-from litellm.types.llms.vertex_ai import *
 from litellm.llms.prompt_templates.factory import (
-    convert_to_gemini_tool_call_result,
+    convert_to_anthropic_image_obj,
     convert_to_gemini_tool_call_invoke,
+    convert_to_gemini_tool_call_result,
 )
 from litellm.types.files import (
     get_file_mime_type_for_file_type,
@@ -20,6 +24,8 @@ from litellm.types.files import (
     is_gemini_1_5_accepted_file_type,
     is_video_file_type,
 )
+from litellm.types.llms.vertex_ai import *
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage


 class VertexAIError(Exception):
@@ -274,28 +280,6 @@ def _get_image_bytes_from_url(image_url: str) -> bytes:
         raise Exception(f"An exception occurs with this image - {str(e)}")


-def _load_image_from_url(image_url: str):
-    """
-    Loads an image from a URL.
-
-    Args:
-        image_url (str): The URL of the image.
-
-    Returns:
-        Image: The loaded image.
-    """
-    from vertexai.preview.generative_models import (
-        GenerativeModel,
-        Part,
-        GenerationConfig,
-        Image,
-    )
-
-    image_bytes = _get_image_bytes_from_url(image_url)
-
-    return Image.from_bytes(data=image_bytes)
-
-
 def _convert_gemini_role(role: str) -> Literal["user", "model"]:
     if role == "user":
         return "user"
@@ -323,28 +307,9 @@ def _process_gemini_image(image_url: str) -> PartType:
             return PartType(file_data=file_data)

         # Direct links
-        elif "https:/" in image_url:
-            image = _load_image_from_url(image_url)
-            _blob = BlobType(data=image.data, mime_type=image._mime_type)
-            return PartType(inline_data=_blob)
-
-        # Base64 encoding
-        elif "base64" in image_url:
-            import base64, re
-
-            # base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
-            image_metadata, img_without_base_64 = image_url.split(",")
-
-            # read mime_type from img_without_base_64=data:image/jpeg;base64
-            # Extract MIME type using regular expression
-            mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
-
-            if mime_type_match:
-                mime_type = mime_type_match.group(1)
-            else:
-                mime_type = "image/jpeg"
-            decoded_img = base64.b64decode(img_without_base_64)
-            _blob = BlobType(data=decoded_img, mime_type=mime_type)
+        elif "https:/" in image_url or "base64" in image_url:
+            image = convert_to_anthropic_image_obj(image_url)
+            _blob = BlobType(data=image["data"], mime_type=image["media_type"])
             return PartType(inline_data=_blob)
         raise Exception("Invalid image received - {}".format(image_url))
     except Exception as e:
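Note: this hunk is the core of the PR on the Vertex SDK path. The old https:// branch loaded images through the vertexai SDK (_load_image_from_url, deleted above) while base64 data URLs were decoded inline; both now funnel through convert_to_anthropic_image_obj, so BlobType carries the base64 string rather than decoded bytes (which is what the data: Required[str] change to BlobType further down asserts). A rough usage sketch with an illustrative payload; the remote case performs a real fetch:

from litellm.llms.vertex_ai import _process_gemini_image

# Both forms now produce an inline_data Part via the same code path:
part_from_data_url = _process_gemini_image("data:image/jpeg;base64,dGVzdA==")
part_from_remote = _process_gemini_image(
    "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
)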
@@ -480,23 +445,25 @@ def completion(
                 message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
             )
     try:
+        import google.auth  # type: ignore
+        import proto  # type: ignore
+        from google.cloud import aiplatform  # type: ignore
+        from google.cloud.aiplatform_v1beta1.types import (
+            content as gapic_content_types,  # type: ignore
+        )
+        from google.protobuf import json_format  # type: ignore
+        from google.protobuf.struct_pb2 import Value  # type: ignore
+        from vertexai.language_models import CodeGenerationModel, TextGenerationModel
+        from vertexai.preview.generative_models import (
+            GenerationConfig,
+            GenerativeModel,
+            Part,
+        )
         from vertexai.preview.language_models import (
             ChatModel,
             CodeChatModel,
             InputOutputTextPair,
         )
-        from vertexai.language_models import TextGenerationModel, CodeGenerationModel
-        from vertexai.preview.generative_models import (
-            GenerativeModel,
-            Part,
-            GenerationConfig,
-        )
-        from google.cloud import aiplatform  # type: ignore
-        from google.protobuf import json_format  # type: ignore
-        from google.protobuf.struct_pb2 import Value  # type: ignore
-        from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
-        import google.auth  # type: ignore
-        import proto  # type: ignore

         ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
         print_verbose(
@@ -1412,8 +1379,8 @@ def embedding(
             message="vertexai import failed please run `pip install google-cloud-aiplatform`",
         )

-    from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput
     import google.auth  # type: ignore
+    from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel

     ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
     try:

View file

@@ -21,6 +21,7 @@ import litellm.litellm_core_utils.litellm_logging
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.prompt_templates.factory import convert_url_to_base64
 from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
 from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
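Note: vertex_httpx (the REST-based Gemini handler used for the vertex_ai_beta and Google AI Studio routes) now pulls in convert_url_to_base64, presumably so remote image URLs can be inlined before the request payload is built. Going by its definition in the prompt factory (fetch with up to three retries, then base64-encode), a sketch of the expected round trip; the exact return format is an assumption inferred from how convert_to_anthropic_image_obj consumes it:

from litellm.llms.prompt_templates.factory import convert_url_to_base64

data_url = convert_url_to_base64(
    "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
)
# Expected shape: "data:image/png;base64,<payload>", i.e. directly usable
# anywhere a base64 image_url is accepted.
assert data_url.startswith("data:image/")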

View file

@@ -568,8 +568,6 @@ async def test_gemini_pro_vision(provider, sync_mode):
         # DO Not DELETE this ASSERT
         # Google counts the prompt tokens for us, we should ensure we use the tokens from the orignal response
         assert prompt_tokens == 263  # the gemini api returns 263 to us

-        # assert False
-
     except litellm.RateLimitError as e:
         pass
     except Exception as e:
@@ -1152,38 +1150,44 @@ async def test_vertexai_aembedding():
 #         raise e
 # test_gemini_pro_vision_stream()

-# def test_gemini_pro_vision_async():
-#     try:
-#         litellm.set_verbose = True
-#         litellm.num_retries=0
-#         async def test():
-#             resp = await litellm.acompletion(
-#                 model = "vertex_ai/gemini-pro-vision",
-#                 messages=[
-#                     {
-#                         "role": "user",
-#                         "content": [
-#                             {
-#                                 "type": "text",
-#                                 "text": "Whats in this image?"
-#                             },
-#                             {
-#                                 "type": "image_url",
-#                                 "image_url": {
-#                                     "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
-#                                 }
-#                             }
-#                         ]
-#                     }
-#                 ],
-#             )
-#             print("async response gemini pro vision")
-#             print(resp)
-#         asyncio.run(test())
-#     except Exception as e:
-#         import traceback
-#         traceback.print_exc()
-#         raise e
+
+def test_gemini_pro_vision_async():
+    try:
+        litellm.set_verbose = True
+        litellm.num_retries = 0
+
+        async def test():
+            load_vertex_ai_credentials()
+            resp = await litellm.acompletion(
+                model="vertex_ai/gemini-pro-vision",
+                messages=[
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "text", "text": "Whats in this image?"},
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+                                },
+                            },
+                        ],
+                    }
+                ],
+            )
+            print("async response gemini pro vision")
+            print(resp)
+
+        asyncio.run(test())
+    except litellm.RateLimitError:
+        pass
+    except Exception as e:
+        import traceback
+
+        traceback.print_exc()
+        raise e


 # test_gemini_pro_vision_async()

View file

@@ -694,8 +694,10 @@ def test_completion_claude_3_base64():
         pytest.fail(f"An exception occurred - {str(e)}")


-@pytest.mark.skip(reason="issue getting wikipedia images in ci/cd")
-def test_completion_claude_3_function_plus_image():
+@pytest.mark.parametrize(
+    "model", ["gemini/gemini-1.5-flash"]  # "claude-3-sonnet-20240229",
+)
+def test_completion_function_plus_image(model):
     litellm.set_verbose = True

     image_content = [
@@ -703,7 +705,7 @@ def test_completion_claude_3_function_plus_image():
         {
             "type": "image_url",
             "image_url": {
-                "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+                "url": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
             },
         },
     ]
@@ -719,7 +721,7 @@ def test_completion_claude_3_function_plus_image():
                     "type": "object",
                     "properties": {
                         "location": {
-                            "type": "text",
+                            "type": "string",
                             "description": "The city and state, e.g. San Francisco, CA",
                         },
                         "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
@@ -739,7 +741,7 @@ def test_completion_claude_3_function_plus_image():
     ]

     response = completion(
-        model="claude-3-sonnet-20240229",
+        model=model,
         messages=[image_message],
         tool_choice=tool_choice,
         tools=tools,
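Note: besides widening this test from Claude-only to a parametrized run against gemini/gemini-1.5-flash, two real fixes ride along. The flaky Wikipedia image URL is replaced with an S3-hosted one (which is also why the @pytest.mark.skip could be dropped), and "type": "text" in the tool schema becomes "type": "string", since JSON Schema has no "text" type. For reference, a corrected tool definition in the usual weather-tool shape; the function name and description are assumptions, as they are not visible in these hunks:

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # assumed name, not shown in this diff
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",  # was "text"; not a valid JSON Schema type
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
            },
        },
    }
]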

View file

@@ -39,7 +39,7 @@ class FileDataType(TypedDict):

 class BlobType(TypedDict):
     mime_type: Required[str]
-    data: Required[bytes]
+    data: Required[str]


 class PartType(TypedDict, total=False):

View file

@@ -971,3 +971,14 @@ class TranscriptionResponse(OpenAIObject):
         except:
             # if using pydantic v1
             return self.dict()
+
+
+class GenericImageParsingChunk(TypedDict):
+    # {
+    #     "type": "base64",
+    #     "media_type": f"image/{image_format}",
+    #     "data": base64_data,
+    # }
+    type: str
+    media_type: str
+    data: str
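Note: GenericImageParsingChunk just formalizes the dict shape convert_to_anthropic_image_obj has always returned (the commented-out literal above it is that original dict). As a TypedDict it adds static checking only; nothing changes at runtime:

from litellm.types.utils import GenericImageParsingChunk

chunk = GenericImageParsingChunk(type="base64", media_type="image/jpeg", data="dGVzdA==")
assert isinstance(chunk, dict)  # TypedDict constructors build plain dicts
assert chunk["media_type"] == "image/jpeg"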

View file

@@ -2647,7 +2647,7 @@ def get_optional_params(
         if presence_penalty is not None:
             optional_params["presencePenalty"] = {"scale": presence_penalty}
     elif (
-        custom_llm_provider == "palm" or custom_llm_provider == "gemini"
+        custom_llm_provider == "palm"
     ):  # https://developers.generativeai.google/tutorials/curl_quickstart
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
@@ -2694,7 +2694,7 @@ def get_optional_params(
         print_verbose(
             f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
         )
-    elif custom_llm_provider == "vertex_ai_beta":
+    elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
@@ -3726,7 +3726,7 @@ def get_supported_openai_params(
     elif request_type == "embeddings":
         return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
-        return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
+        return litellm.VertexAIConfig().get_supported_openai_params()
     elif custom_llm_provider == "vertex_ai":
         if request_type == "chat_completion":
             return litellm.VertexAIConfig().get_supported_openai_params()
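Note: taken together, the three hunks above move gemini off the legacy PaLM parameter path. Parameter mapping for gemini now goes through the vertex_ai_beta branch of get_optional_params, and the palm/gemini branch of get_supported_openai_params returns VertexAIConfig's list instead of the old hard-coded six entries, which is what makes tool calls plus image input work end to end for gemini/* models. A sketch, assuming get_supported_openai_params is importable from the package root as in current litellm:

from litellm import get_supported_openai_params

params = get_supported_openai_params(
    model="gemini-1.5-flash", custom_llm_provider="gemini"
)
# Previously hard-coded: ["temperature", "top_p", "stream", "n", "stop", "max_tokens"];
# now whatever VertexAIConfig().get_supported_openai_params() reports.
print(params)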