fix(utils.py): fix get_image_dimensions to handle more image types

Fixes https://github.com/BerriAI/litellm/issues/5205
This commit is contained in:
Krrish Dholakia 2024-08-16 12:00:04 -07:00
parent cbdaecb5a8
commit 7129e93992
2 changed files with 50 additions and 23 deletions

View file

@ -14,7 +14,9 @@ import binascii
import copy
import datetime
import hashlib
import imghdr
import inspect
import io
import itertools
import json
import logging
@ -1797,35 +1799,41 @@ def calculate_tiles_needed(
def get_image_dimensions(data):
img_data = None
# Check if data is a URL by trying to parse it
try:
response = requests.get(data)
response.raise_for_status() # Check if the request was successful
img_data = response.content
# Try to open as URL
# Try to open as URL
client = HTTPHandler(concurrent_limit=1)
response = client.get(data)
img_data = response.read()
except Exception:
# Data is not a URL, handle as base64
# If not URL, assume it's base64
header, encoded = data.split(",", 1)
img_data = base64.b64decode(encoded)
# Try to determine dimensions from headers
# This is a very simplistic check, primarily works with PNG and non-progressive JPEG
if img_data[:8] == b"\x89PNG\r\n\x1a\n":
# PNG Image; width and height are 4 bytes each and start at offset 16
width, height = struct.unpack(">ii", img_data[16:24])
return width, height
elif img_data[:2] == b"\xff\xd8":
# JPEG Image; for dimensions, SOF0 block (0xC0) gives dimensions at offset 3 for length, and then 5 and 7 for height and width
# This will NOT find dimensions for all JPEGs (e.g., progressive JPEGs)
# Find SOF0 marker (0xFF followed by 0xC0)
sof = re.search(b"\xff\xc0....", img_data)
if sof:
# Parse SOF0 block to find dimensions
height, width = struct.unpack(">HH", sof.group()[5:9])
return width, height
else:
return None, None
img_type = imghdr.what(None, h=img_data)
if img_type == "png":
w, h = struct.unpack(">LL", img_data[16:24])
return w, h
elif img_type == "gif":
w, h = struct.unpack("<HH", img_data[6:10])
return w, h
elif img_type == "jpeg":
with io.BytesIO(img_data) as fhandle:
fhandle.seek(0)
size = 2
ftype = 0
while not 0xC0 <= ftype <= 0xCF or ftype in (0xC4, 0xC8, 0xCC):
fhandle.seek(size, 1)
byte = fhandle.read(1)
while ord(byte) == 0xFF:
byte = fhandle.read(1)
ftype = ord(byte)
size = struct.unpack(">H", fhandle.read(2))[0] - 2
fhandle.seek(1, 1)
h, w = struct.unpack(">HH", fhandle.read(4))
return w, h
else:
# Unsupported format
return None, None