Merge pull request #5255 from BerriAI/litellm_fix_token_counter

fix(utils.py): fix get_image_dimensions to handle more image types
This commit is contained in:
Krish Dholakia 2024-08-16 17:27:27 -07:00 committed by GitHub
commit 6cf8c47366
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 50 additions and 23 deletions

View file

@ -356,3 +356,22 @@ def test_gpt_4o_token_counter():
) )
mock_client.assert_called() mock_client.assert_called()
@pytest.mark.parametrize(
"img_url",
[
"https://blog.purpureus.net/assets/blog/personal_key_rotation/simplified-asset-graph.jpg",
"",
],
)
def test_img_url_token_counter(img_url):
from litellm.utils import get_image_dimensions
width, height = get_image_dimensions(data=img_url)
print(width, height)
assert width is not None
assert height is not None

View file

@ -14,7 +14,9 @@ import binascii
import copy import copy
import datetime import datetime
import hashlib import hashlib
import imghdr
import inspect import inspect
import io
import itertools import itertools
import json import json
import logging import logging
@ -1797,35 +1799,41 @@ def calculate_tiles_needed(
def get_image_dimensions(data): def get_image_dimensions(data):
img_data = None img_data = None
# Check if data is a URL by trying to parse it
try: try:
response = requests.get(data) # Try to open as URL
response.raise_for_status() # Check if the request was successful # Try to open as URL
img_data = response.content client = HTTPHandler(concurrent_limit=1)
response = client.get(data)
img_data = response.read()
except Exception: except Exception:
# Data is not a URL, handle as base64 # If not URL, assume it's base64
header, encoded = data.split(",", 1) header, encoded = data.split(",", 1)
img_data = base64.b64decode(encoded) img_data = base64.b64decode(encoded)
# Try to determine dimensions from headers img_type = imghdr.what(None, h=img_data)
# This is a very simplistic check, primarily works with PNG and non-progressive JPEG
if img_data[:8] == b"\x89PNG\r\n\x1a\n": if img_type == "png":
# PNG Image; width and height are 4 bytes each and start at offset 16 w, h = struct.unpack(">LL", img_data[16:24])
width, height = struct.unpack(">ii", img_data[16:24]) return w, h
return width, height elif img_type == "gif":
elif img_data[:2] == b"\xff\xd8": w, h = struct.unpack("<HH", img_data[6:10])
# JPEG Image; for dimensions, SOF0 block (0xC0) gives dimensions at offset 3 for length, and then 5 and 7 for height and width return w, h
# This will NOT find dimensions for all JPEGs (e.g., progressive JPEGs) elif img_type == "jpeg":
# Find SOF0 marker (0xFF followed by 0xC0) with io.BytesIO(img_data) as fhandle:
sof = re.search(b"\xff\xc0....", img_data) fhandle.seek(0)
if sof: size = 2
# Parse SOF0 block to find dimensions ftype = 0
height, width = struct.unpack(">HH", sof.group()[5:9]) while not 0xC0 <= ftype <= 0xCF or ftype in (0xC4, 0xC8, 0xCC):
return width, height fhandle.seek(size, 1)
else: byte = fhandle.read(1)
return None, None while ord(byte) == 0xFF:
byte = fhandle.read(1)
ftype = ord(byte)
size = struct.unpack(">H", fhandle.read(2))[0] - 2
fhandle.seek(1, 1)
h, w = struct.unpack(">HH", fhandle.read(4))
return w, h
else: else:
# Unsupported format
return None, None return None, None