fix(utils.py): support token counting for gpt-4-vision models

This commit is contained in:
Krrish Dholakia 2024-01-02 14:41:28 +05:30
parent eda6ab8cdc
commit 0fffcc1579
3 changed files with 237 additions and 7 deletions

View file

@ -0,0 +1,78 @@
from typing import Literal
def calculage_img_tokens(
    width,
    height,
    mode: Literal["low", "high", "auto"] = "auto",
    base_tokens: int = 85,  # openai default - https://openai.com/pricing
):
    """Estimate the OpenAI vision token cost for an image of the given size.

    "low" detail costs a flat ``base_tokens``; "high" (and "auto", which this
    helper treats the same as high) additionally charges 2x ``base_tokens``
    for every 512px tile after the image is resized into high-res bounds.
    """
    if mode == "low":
        return base_tokens
    elif mode in ("high", "auto"):
        new_w, new_h = resize_image_high_res(width=width, height=height)
        tile_count = calculate_tiles_needed(new_w, new_h)
        # one base charge plus 2x base tokens per 512x512 tile
        return base_tokens + (base_tokens * 2) * tile_count
def resize_image_high_res(width, height):
    """Scale an image into OpenAI high-res detail bounds.

    The shorter side is pinned to 768px; if that pushes the longer side past
    2000px, both sides are shrunk proportionally so the longer side lands on
    2000px instead. Returns the resized ``(width, height)``.
    """
    max_short_side = 768
    max_long_side = 2000
    # ratio of the long side to the short side, always >= 1
    aspect_ratio = max(width, height) / min(width, height)
    if width <= height:
        # portrait or square: width is the short side
        new_w = max_short_side
        new_h = int(new_w * aspect_ratio)
        if new_h > max_long_side:
            # long edge overflowed the cap - clamp it and rescale the width
            new_h = max_long_side
            new_w = int(new_h / aspect_ratio)
    else:
        # landscape: height is the short side
        new_h = max_short_side
        new_w = int(new_h * aspect_ratio)
        if new_w > max_long_side:
            # long edge overflowed the cap - clamp it and rescale the height
            new_w = max_long_side
            new_h = int(new_w / aspect_ratio)
    return new_w, new_h
# Test the function with the given example
def calculate_tiles_needed(
    resized_width, resized_height, tile_width=512, tile_height=512
):
    """Count how many tile_width x tile_height tiles cover the image.

    Partial tiles at the right/bottom edges count as full tiles
    (ceiling division on each axis).
    """
    cols = -(-resized_width // tile_width)  # integer ceiling, no floats
    rows = -(-resized_height // tile_height)
    return cols * rows
# Sanity-check the tiling math against a known high-res example (1875 x 768).
resized_width_high_res, resized_height_high_res = 1875, 768
tiles_needed_high_res = calculate_tiles_needed(
    resized_width_high_res, resized_height_high_res
)
print(
    "Tiles needed for high res image "
    f"({resized_width_high_res}x{resized_height_high_res}): {tiles_needed_high_res}"
)
# Exercise the full pipeline: resize an original image, then tile it.
original_size = (10000, 4096)
resized_size_high_res = resize_image_high_res(*original_size)
print(f"Resized dimensions in high res mode: {resized_size_high_res}")
tiles_needed = calculate_tiles_needed(*resized_size_high_res)
print(f"Tiles needed for high res image {resized_size_high_res}: {tiles_needed}")

View file

@ -119,3 +119,23 @@ def test_encoding_and_decoding():
# test_encoding_and_decoding() # test_encoding_and_decoding()
def test_gpt_vision_token_counting():
    # Mixed text + image content must route through the vision-aware counter
    # instead of crashing on the non-string "content" list.
    image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
    content = [
        {"type": "text", "text": "Whats in this image?"},
        {"type": "image_url", "image_url": image_url},
    ]
    messages = [{"role": "user", "content": content}]
    tokens = token_counter(model="gpt-4-vision-preview", messages=messages)
    print(f"tokens: {tokens}")


# test_gpt_vision_token_counting()

View file

@ -7,7 +7,7 @@
# #
# Thank you users! We ❤️ you! - Krrish & Ishaan # Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, re, binascii import sys, re, binascii, struct
import litellm import litellm
import dotenv, json, traceback, threading, base64 import dotenv, json, traceback, threading, base64
import subprocess, os import subprocess, os
@ -2495,15 +2495,127 @@ def openai_token_counter(
for message in messages: for message in messages:
num_tokens += tokens_per_message num_tokens += tokens_per_message
for key, value in message.items(): for key, value in message.items():
num_tokens += len(encoding.encode(value, disallowed_special=())) if isinstance(value, str):
if key == "name": num_tokens += len(encoding.encode(value, disallowed_special=()))
num_tokens += tokens_per_name if key == "name":
num_tokens += tokens_per_name
elif isinstance(value, List):
for c in value:
if c["type"] == "text":
text += c["text"]
elif c["type"] == "image_url":
if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
url = image_url_dict.get("url")
num_tokens += calculage_img_tokens(
data=url, mode=detail
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculage_img_tokens(
data=image_url_str, mode="auto"
)
elif text is not None: elif text is not None:
num_tokens = len(encoding.encode(text, disallowed_special=())) num_tokens = len(encoding.encode(text, disallowed_special=()))
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens return num_tokens
def resize_image_high_res(width, height):
    """Fit an image into OpenAI's high-detail resize bounds.

    Pins the shorter side at 768px, then clamps the longer side to 2000px
    (rescaling the short side to keep the aspect ratio) if it overshoots.
    Returns the resized ``(width, height)``.
    """
    short_cap = 768
    long_cap = 2000
    # ratio of the long side to the short side, always >= 1
    ratio = max(width, height) / min(width, height)
    portrait = width <= height
    short_side = short_cap
    long_side = int(short_cap * ratio)
    if long_side > long_cap:
        # the long edge blew past the cap; clamp it and shrink the short edge
        long_side = long_cap
        short_side = int(long_cap / ratio)
    # width is the short side in portrait/square, the long side in landscape
    return (short_side, long_side) if portrait else (long_side, short_side)
# Test the function with the given example
def calculate_tiles_needed(
    resized_width, resized_height, tile_width=512, tile_height=512
):
    """Number of tiles required to cover the image; partial tiles round up."""
    horizontal_tiles = (resized_width + tile_width - 1) // tile_width
    vertical_tiles = (resized_height + tile_height - 1) // tile_height
    return horizontal_tiles * vertical_tiles
def get_image_dimensions(data):
    """Best-effort ``(width, height)`` extraction for a PNG or JPEG image.

    ``data`` may be a fetchable URL, a full ``data:image/...;base64,<payload>``
    URI, or a bare base64 payload. Returns ``(None, None)`` when the format is
    unsupported or the dimensions cannot be located.
    """
    img_data = None
    # Check if data is a URL by trying to fetch it
    try:
        # bounded wait so token counting can't hang on a dead host
        response = requests.get(data, timeout=5)
        response.raise_for_status()  # Check if the request was successful
        img_data = response.content
    except Exception:
        # Data is not a URL; handle as base64, with or without a data-URI header
        if "," in data:
            header, encoded = data.split(",", 1)
        else:
            encoded = data
        img_data = base64.b64decode(encoded)
    # Try to determine dimensions from headers
    # This is a very simplistic check, primarily works with PNG and non-progressive JPEG
    if img_data[:8] == b"\x89PNG\r\n\x1a\n":
        # PNG: IHDR width/height are UNSIGNED 4-byte big-endian ints at offset 16
        # (">II", not ">ii" - signed unpack would corrupt very large dimensions)
        width, height = struct.unpack(">II", img_data[16:24])
        return width, height
    elif img_data[:2] == b"\xff\xd8":
        # JPEG: the SOF0 block (0xFFC0) stores height then width at marker offsets 5..8
        # This will NOT find dimensions for all JPEGs (e.g., progressive JPEGs)
        # re.DOTALL is required: without it "." skips 0x0A bytes and misses markers
        sof = re.search(b"\xff\xc0....", img_data, re.DOTALL)
        if sof:
            # Slice from the full payload: the regex match itself is only 6 bytes,
            # so sof.group()[5:9] would be a 1-byte slice and struct.unpack would fail
            start = sof.start()
            height, width = struct.unpack(">HH", img_data[start + 5 : start + 9])
            return width, height
        else:
            return None, None
    else:
        # Unsupported format
        return None, None
def calculage_img_tokens(
    data,
    mode: Literal["low", "high", "auto"] = "auto",
    base_tokens: int = 85,  # openai default - https://openai.com/pricing
):
    """Estimate the OpenAI vision token cost for one image.

    ``data`` is an image URL or base64 payload (see ``get_image_dimensions``).
    "low" and "auto" detail cost a flat ``base_tokens``; "high" detail adds
    2x ``base_tokens`` for every 512px tile of the resized image.

    Raises:
        ValueError: if ``mode`` is not "low", "high", or "auto". The previous
            implementation fell through and returned None, which crashed later
            with a TypeError when added to a running token count.
    """
    if mode == "low" or mode == "auto":
        return base_tokens
    elif mode == "high":
        width, height = get_image_dimensions(data=data)
        resized_width, resized_height = resize_image_high_res(
            width=width, height=height
        )
        tiles_needed_high_res = calculate_tiles_needed(resized_width, resized_height)
        tile_tokens = (base_tokens * 2) * tiles_needed_high_res
        total_tokens = base_tokens + tile_tokens
        return total_tokens
    # "detail" arrives straight from user-supplied message content, so guard it
    raise ValueError(
        f"Invalid mode for image token calculation: {mode}. Expected 'low', 'high', or 'auto'."
    )
def token_counter( def token_counter(
model="", model="",
text: Optional[Union[str, List[str]]] = None, text: Optional[Union[str, List[str]]] = None,
@ -2522,13 +2634,33 @@ def token_counter(
""" """
# use tiktoken, anthropic, cohere or llama2's tokenizer depending on the model # use tiktoken, anthropic, cohere or llama2's tokenizer depending on the model
is_tool_call = False is_tool_call = False
num_tokens = 0
if text == None: if text == None:
if messages is not None: if messages is not None:
print_verbose(f"token_counter messages received: {messages}") print_verbose(f"token_counter messages received: {messages}")
text = "" text = ""
for message in messages: for message in messages:
if message.get("content", None): if message.get("content", None) is not None:
text += message["content"] content = message.get("content")
if isinstance(content, str):
text += message["content"]
elif isinstance(content, List):
for c in content:
if c["type"] == "text":
text += c["text"]
elif c["type"] == "image_url":
if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
url = image_url_dict.get("url")
num_tokens += calculage_img_tokens(
data=url, mode=detail
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculage_img_tokens(
data=image_url_str, mode="auto"
)
if "tool_calls" in message: if "tool_calls" in message:
is_tool_call = True is_tool_call = True
for tool_call in message["tool_calls"]: for tool_call in message["tool_calls"]:
@ -2539,7 +2671,7 @@ def token_counter(
raise ValueError("text and messages cannot both be None") raise ValueError("text and messages cannot both be None")
elif isinstance(text, List): elif isinstance(text, List):
text = "".join(t for t in text if isinstance(t, str)) text = "".join(t for t in text if isinstance(t, str))
num_tokens = 0
if model is not None: if model is not None:
tokenizer_json = _select_tokenizer(model=model) tokenizer_json = _select_tokenizer(model=model)
if tokenizer_json["type"] == "huggingface_tokenizer": if tokenizer_json["type"] == "huggingface_tokenizer":