forked from phoenix/litellm-mirror
fix(factory.py): support gemini-pro-vision on google ai studio
https://github.com/BerriAI/litellm/issues/1329
parent a7245dba07
commit 5fd2f945f3
4 changed files with 132 additions and 2 deletions
litellm/llms/prompt_templates/factory.py
@ -385,6 +385,81 @@ def anthropic_pt(
         prompt += f"{AnthropicConstants.AI_PROMPT.value}"
     return prompt


+def _load_image_from_url(image_url: str):
+    """
+    Loads an image from a URL.
+
+    Args:
+        image_url (str): The URL of the image.
+
+    Returns:
+        Image: The loaded image.
+    """
+    from io import BytesIO
+
+    try:
+        from PIL import Image
+    except:
+        raise Exception("gemini image conversion failed please run `pip install Pillow`")
+
+    # Download the image from the URL
+    response = requests.get(image_url)
+    image = Image.open(BytesIO(response.content))
+
+    return image
+
+
+def _gemini_vision_convert_messages(messages: list):
+    """
+    Converts given messages for GPT-4 Vision to Gemini format.
+
+    Args:
+        messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type:
+            - If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt.
+            - If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images.
+
+    Returns:
+        list: A list containing the prompt (a string) followed by the processed images (PIL image objects).
+    """
+    try:
+        from PIL import Image
+    except:
+        raise Exception("gemini image conversion failed please run `pip install Pillow`")
+
+    try:
+        # given messages for gpt-4 vision, convert them for gemini
+        # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
+        prompt = ""
+        images = []
+        for message in messages:
+            if isinstance(message["content"], str):
+                prompt += message["content"]
+            elif isinstance(message["content"], list):
+                # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
+                for element in message["content"]:
+                    if isinstance(element, dict):
+                        if element["type"] == "text":
+                            prompt += element["text"]
+                        elif element["type"] == "image_url":
+                            image_url = element["image_url"]["url"]
+                            images.append(image_url)
+        # processing images passed to gemini
+        processed_images = []
+        for img in images:
+            if "https:/" in img:
+                # Case 1: Image from URL
+                image = _load_image_from_url(img)
+                processed_images.append(image)
+            else:
+                # Case 2: Image filepath (e.g. temp.jpeg) given
+                image = Image.open(img)
+                processed_images.append(image)
+        content = [prompt] + processed_images
+        return content
+    except Exception as e:
+        raise e
+
+
 def gemini_text_image_pt(messages: list):
     """
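For context while reading the hunk: a minimal sketch of what the new converter produces for an OpenAI-vision-style message (the image URL below is a hypothetical placeholder, and Pillow is assumed to be installed):

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
            ],
        }
    ]
    # Strings containing "https:/" are fetched via _load_image_from_url;
    # anything else is treated as a local filepath and opened with PIL directly.
    content = _gemini_vision_convert_messages(messages=messages)
    # content == ["What's in this image?", <PIL.Image.Image>]: the concatenated
    # prompt string followed by PIL image objects, in the order Gemini expects.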
@ -511,7 +586,10 @@ def prompt_factory(
             messages=messages, prompt_format=prompt_format, chat_template=chat_template
         )
     elif custom_llm_provider == "gemini":
-        return gemini_text_image_pt(messages=messages)
+        if model == "gemini-pro-vision":
+            return _gemini_vision_convert_messages(messages=messages)
+        else:
+            return gemini_text_image_pt(messages=messages)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
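The same change seen from the caller's side, as a sketch (it assumes prompt_factory accepts model, messages, and custom_llm_provider as keyword arguments, which the branch conditions in this hunk imply):

    # "gemini-pro-vision" messages now go through the image-aware converter;
    # every other gemini model keeps the existing text-only template.
    content = prompt_factory(
        model="gemini-pro-vision",
        messages=messages,
        custom_llm_provider="gemini",
    )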
litellm/tests/test_google_ai_studio_gemini.py (new file, 33 lines)
@ -0,0 +1,33 @@
+import os, sys, traceback
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from dotenv import load_dotenv
+
+
+def generate_text():
+    try:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "What's in this image?"
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+                        }
+                    }
+                ]
+            }
+        ]
+        response = litellm.completion(model="gemini/gemini-pro-vision", messages=messages)
+        print(response)
+    except Exception as exception:
+        raise Exception("An error occurred during text generation:", exception)
+
+
+generate_text()
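Note the test imports load_dotenv but never calls it, so the API key must already be exported in the shell when the file is run. A sketch of wiring the key up explicitly, assuming litellm's Google AI Studio provider reads GEMINI_API_KEY and returns its usual OpenAI-compatible response object:

    from dotenv import load_dotenv

    load_dotenv()  # populates os.environ from a local .env, e.g. GEMINI_API_KEY=<key>

    import litellm

    response = litellm.completion(
        model="gemini/gemini-pro-vision",
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(response.choices[0].message.content)  # OpenAI-style choices/message shape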
litellm/utils.py
@ -4012,7 +4012,10 @@ def get_llm_provider(
                 api_base = "https://api.voyageai.com/v1"
                 dynamic_api_key = get_secret("VOYAGE_API_KEY")
             return model, custom_llm_provider, dynamic_api_key, api_base
-
+        elif model.split("/", 1)[0] in litellm.provider_list:
+            custom_llm_provider = model.split("/", 1)[0]
+            model = model.split("/", 1)[1]
+            return model, custom_llm_provider, dynamic_api_key, api_base
         # check if api base is a known openai compatible endpoint
         if api_base:
            for endpoint in litellm.openai_compatible_endpoints:
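The new elif gives any "provider/model" string a generic fallback route; a sketch of what it resolves for the gemini prefix (tuple order taken from the return statement in the hunk):

    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
        model="gemini/gemini-pro-vision"
    )
    # model == "gemini-pro-vision", custom_llm_provider == "gemini";
    # dynamic_api_key and api_base pass through unchanged (None unless set earlier).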
model_prices_and_context_window.json
@ -583,6 +583,22 @@
         "litellm_provider": "palm",
         "mode": "completion"
     },
+    "gemini/gemini-pro": {
+        "max_tokens": 30720,
+        "max_output_tokens": 2048,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "gemini",
+        "mode": "chat"
+    },
+    "gemini/gemini-pro-vision": {
+        "max_tokens": 30720,
+        "max_output_tokens": 2048,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "gemini",
+        "mode": "chat"
+    },
     "command-nightly": {
         "max_tokens": 4096,
         "input_cost_per_token": 0.000015,
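These JSON entries feed litellm's cost and context-window lookups; a minimal sketch, assuming the in-memory litellm.model_cost mapping mirrors this file:

    import litellm

    # Both Google AI Studio models register a 30720-token input window and,
    # at the time of this commit, zero per-token cost (free tier).
    info = litellm.model_cost["gemini/gemini-pro-vision"]
    print(info["max_tokens"], info["input_cost_per_token"])  # -> 30720 0.0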