fix(factory.py): support gemini-pro-vision on google ai studio

https://github.com/BerriAI/litellm/issues/1329
Krrish Dholakia 2024-01-06 22:36:22 +05:30
parent a7245dba07
commit 5fd2f945f3
4 changed files with 132 additions and 2 deletions


@@ -385,6 +385,81 @@ def anthropic_pt(
    prompt += f"{AnthropicConstants.AI_PROMPT.value}"
    return prompt


def _load_image_from_url(image_url: str):
    """
    Loads an image from a URL.

    Args:
        image_url (str): The URL of the image.

    Returns:
        Image: The loaded image.
    """
    from io import BytesIO

    try:
        from PIL import Image
    except:
        raise Exception("gemini image conversion failed please run `pip install Pillow`")

    # Download the image from the URL
    response = requests.get(image_url)
    image = Image.open(BytesIO(response.content))
    return image
def _gemini_vision_convert_messages(messages: list):
    """
    Converts given messages for GPT-4 Vision to Gemini format.

    Args:
        messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type:
            - If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt.
            - If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images.

    Returns:
        list: A list whose first element is the prompt string, followed by the processed images (PIL Image objects).
    """
    try:
        from PIL import Image
    except:
        raise Exception("gemini image conversion failed please run `pip install Pillow`")

    try:
        # given messages for gpt-4 vision, convert them for gemini
        # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
        prompt = ""
        images = []
        for message in messages:
            if isinstance(message["content"], str):
                prompt += message["content"]
            elif isinstance(message["content"], list):
                # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
                for element in message["content"]:
                    if isinstance(element, dict):
                        if element["type"] == "text":
                            prompt += element["text"]
                        elif element["type"] == "image_url":
                            image_url = element["image_url"]["url"]
                            images.append(image_url)
        # processing images passed to gemini
        processed_images = []
        for img in images:
            if "https:/" in img:
                # Case 1: Image from URL
                image = _load_image_from_url(img)
                processed_images.append(image)
            else:
                # Case 2: Image filepath (e.g. temp.jpeg) given
                image = Image.open(img)
                processed_images.append(image)
        content = [prompt] + processed_images
        return content
    except Exception as e:
        raise e


def gemini_text_image_pt(messages: list):
    """
@@ -511,7 +586,10 @@ def prompt_factory(
            messages=messages, prompt_format=prompt_format, chat_template=chat_template
        )
    elif custom_llm_provider == "gemini":
        return gemini_text_image_pt(messages=messages)
        if model == "gemini-pro-vision":
            return _gemini_vision_convert_messages(messages=messages)
        else:
            return gemini_text_image_pt(messages=messages)
    try:
        if "meta-llama/llama-2" in model and "chat" in model:
            return llama_2_chat_pt(messages=messages)
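
The hunk above replaces the unconditional gemini_text_image_pt call with a model check, so gemini-pro-vision requests go through _gemini_vision_convert_messages instead. A rough sketch of what that routing produces, assuming the factory module path implied by the commit title and the prompt_factory signature; the sample text and image URL are made up:

# Sketch only: module path and prompt_factory signature are assumptions, not part of this diff.
from litellm.llms.prompt_templates.factory import prompt_factory

# Hypothetical OpenAI-style vision message (one text element, one image_url element).
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this picture."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]

# For gemini-pro-vision the factory should now return the Gemini-style content list:
# ["Describe this picture.", <PIL.Image.Image>] -- the prompt string followed by PIL images
# (the image URL is downloaded via requests and opened with PIL inside the helper).
content = prompt_factory(
    model="gemini-pro-vision", messages=messages, custom_llm_provider="gemini"
)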


@@ -0,0 +1,33 @@
import os, sys, traceback

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from dotenv import load_dotenv


def generate_text():
    try:
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "What's in this image?"
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                        }
                    }
                ]
            }
        ]
        response = litellm.completion(model="gemini/gemini-pro-vision", messages=messages)
        print(response)
    except Exception as exception:
        raise Exception("An error occurred during text generation:", exception)


generate_text()
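
The test above exercises the URL path (Case 1 in _gemini_vision_convert_messages); the conversion also accepts a local file path, which is opened directly with PIL (Case 2, e.g. temp.jpeg). A hypothetical local-file variant of the same call, with a made-up path:

import litellm

# Anything without "https:/" in the url field is treated as a filepath and opened
# with PIL.Image.open() by _gemini_vision_convert_messages. "temp.jpeg" is illustrative.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": "temp.jpeg"}},
        ],
    }
]

response = litellm.completion(model="gemini/gemini-pro-vision", messages=messages)
print(response)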


@@ -4012,7 +4012,10 @@ def get_llm_provider(
            api_base = "https://api.voyageai.com/v1"
            dynamic_api_key = get_secret("VOYAGE_API_KEY")
            return model, custom_llm_provider, dynamic_api_key, api_base
        elif model.split("/", 1)[0] in litellm.provider_list:
            custom_llm_provider = model.split("/", 1)[0]
            model = model.split("/", 1)[1]
            return model, custom_llm_provider, dynamic_api_key, api_base
        # check if api base is a known openai compatible endpoint
        if api_base:
            for endpoint in litellm.openai_compatible_endpoints:
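
The new elif in get_llm_provider means any provider/model string whose prefix appears in litellm.provider_list is split once on "/" and returned immediately, which is how "gemini/gemini-pro-vision" resolves to provider "gemini" and model "gemini-pro-vision". A standalone sketch of that rule; the provider list below is an illustrative subset, not litellm's actual list:

# Minimal sketch of the provider-prefix routing added above; not litellm code.
PROVIDER_LIST = ["openai", "anthropic", "palm", "gemini"]  # illustrative subset

def split_provider(model: str):
    prefix = model.split("/", 1)[0]
    if prefix in PROVIDER_LIST:
        return model.split("/", 1)[1], prefix  # (model, custom_llm_provider)
    return model, None

print(split_provider("gemini/gemini-pro-vision"))  # ('gemini-pro-vision', 'gemini')
print(split_provider("gpt-3.5-turbo"))             # ('gpt-3.5-turbo', None)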


@@ -583,6 +583,22 @@
        "litellm_provider": "palm",
        "mode": "completion"
    },
    "gemini/gemini-pro": {
        "max_tokens": 30720,
        "max_output_tokens": 2048,
        "input_cost_per_token": 0.0,
        "output_cost_per_token": 0.0,
        "litellm_provider": "gemini",
        "mode": "chat"
    },
    "gemini/gemini-pro-vision": {
        "max_tokens": 30720,
        "max_output_tokens": 2048,
        "input_cost_per_token": 0.0,
        "output_cost_per_token": 0.0,
        "litellm_provider": "gemini",
        "mode": "chat"
    },
    "command-nightly": {
        "max_tokens": 4096,
        "input_cost_per_token": 0.000015,