fix: Support WebP image format and avoid token calculation error (#7182)

* fix get_image_dimensions

* attempt without pillow

* add clear type hints

* fix run_async_function_within_sync_function

* fix calculage_img_tokens

* fix is_prompt_caching_valid_prompt

* fix naming

* fix calculate_img_tokens

* fix unused imports

* fix calculate_img_tokens

* test test_is_prompt_caching_enabled_error_handling

* test_is_prompt_caching_enabled_return_default_image_dimensions

* fix openai_token_counter

* fix get_image_dimensions

* test_token_counter_with_image_url_with_detail_high

* test_img_url_token_counter

* fix test utils

* fix testing

* test_is_prompt_caching_enabled
Ishaan Jaff 2024-12-12 14:32:39 -08:00 committed by GitHub
parent c6d6bda76c
commit 8c7605a164
8 changed files with 336 additions and 143 deletions
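The hunks below show the utils-side changes: the misspelled calculage_img_tokens helper is renamed to calculate_img_tokens and moved into (and imported from) litellm/litellm_core_utils/token_counter.py, a use_default_image_token_count flag is threaded through token_counter and openai_token_counter, and is_prompt_caching_valid_prompt is wrapped in a try/except. The WebP handling named in the title lands in the relocated image helpers and is not visible in this view; purely as an illustration, WebP can be recognized by its RIFF container magic bytes, e.g. a hypothetical extension of the get_image_type helper shown further down:

    def get_image_type(image_data: bytes):
        # ... existing png / gif / jpeg / heic checks ...
        # WebP files are RIFF containers: bytes 0-3 read b"RIFF" and bytes 8-11 read b"WEBP"
        if image_data[0:4] == b"RIFF" and image_data[8:12] == b"WEBP":
            return "webp"
        return None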


@@ -95,7 +95,10 @@ from litellm.litellm_core_utils.redact_messages import (
)
from litellm.litellm_core_utils.rules import Rules
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
from litellm.litellm_core_utils.token_counter import (
calculate_img_tokens,
get_modified_max_tokens,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.router_utils.get_retry_from_policy import (
get_num_retries_from_retry_policy,
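With the import above, calculate_img_tokens is exposed from the token_counter module. A minimal usage sketch, assuming the renamed function keeps the data/mode signature of the old calculage_img_tokens shown further down:

    from litellm.litellm_core_utils.token_counter import calculate_img_tokens

    # "low" (and "auto") detail short-circuits to the flat 85 base tokens,
    # so the image is never fetched or decoded
    print(calculate_img_tokens(data="https://example.com/photo.webp", mode="low"))  # 85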
@@ -1283,6 +1286,7 @@ def openai_token_counter( # noqa: PLR0915
count_response_tokens: Optional[
bool
] = False, # Flag passed from litellm.stream_chunk_builder, to indicate counting tokens for LLM Response. We need this because for LLM input we add +3 tokens per message - based on OpenAI's token counter
use_default_image_token_count: Optional[bool] = False,
):
"""
Return the number of tokens used by a list of messages.
@@ -1341,13 +1345,19 @@ def openai_token_counter( # noqa: PLR0915
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
url = image_url_dict.get("url")
num_tokens += calculage_img_tokens(
data=url, mode=detail
num_tokens += calculate_img_tokens(
data=url,
mode=detail,
use_default_image_token_count=use_default_image_token_count
or False,
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculage_img_tokens(
data=image_url_str, mode="auto"
num_tokens += calculate_img_tokens(
data=image_url_str,
mode="auto",
use_default_image_token_count=use_default_image_token_count
or False,
)
elif text is not None and count_response_tokens is True:
# This is the case where we need to count tokens for a streamed response. We should NOT add +3 tokens per message in this branch
@@ -1375,130 +1385,6 @@ def openai_token_counter( # noqa: PLR0915
return num_tokens
def resize_image_high_res(width, height):
# Maximum dimensions for high res mode
max_short_side = 768
max_long_side = 2000
# Return early if no resizing is needed
if width <= 768 and height <= 768:
return width, height
# Determine the longer and shorter sides
longer_side = max(width, height)
shorter_side = min(width, height)
# Calculate the aspect ratio
aspect_ratio = longer_side / shorter_side
# Resize based on the short side being 768px
if width <= height: # Portrait or square
resized_width = max_short_side
resized_height = int(resized_width * aspect_ratio)
# if the long side exceeds the limit after resizing, adjust both sides accordingly
if resized_height > max_long_side:
resized_height = max_long_side
resized_width = int(resized_height / aspect_ratio)
else: # Landscape
resized_height = max_short_side
resized_width = int(resized_height * aspect_ratio)
# if the long side exceeds the limit after resizing, adjust both sides accordingly
if resized_width > max_long_side:
resized_width = max_long_side
resized_height = int(resized_width / aspect_ratio)
return resized_width, resized_height
# Test the function with the given example
def calculate_tiles_needed(
resized_width, resized_height, tile_width=512, tile_height=512
):
tiles_across = (resized_width + tile_width - 1) // tile_width
tiles_down = (resized_height + tile_height - 1) // tile_height
total_tiles = tiles_across * tiles_down
return total_tiles
def get_image_type(image_data: bytes) -> Union[str, None]:
"""take an image (really only the first ~100 bytes max are needed)
and return 'png' 'gif' 'jpeg' 'heic' or None. method added to
allow deprecation of imghdr in 3.13"""
if image_data[0:8] == b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a":
return "png"
if image_data[0:4] == b"GIF8" and image_data[5:6] == b"a":
return "gif"
if image_data[0:3] == b"\xff\xd8\xff":
return "jpeg"
if image_data[4:8] == b"ftyp":
return "heic"
return None
def get_image_dimensions(data):
img_data = None
try:
# Try to open as URL
client = HTTPHandler(concurrent_limit=1)
response = client.get(data)
img_data = response.read()
except Exception:
# If not URL, assume it's base64
header, encoded = data.split(",", 1)
img_data = base64.b64decode(encoded)
img_type = get_image_type(img_data)
if img_type == "png":
w, h = struct.unpack(">LL", img_data[16:24])
return w, h
elif img_type == "gif":
w, h = struct.unpack("<HH", img_data[6:10])
return w, h
elif img_type == "jpeg":
with io.BytesIO(img_data) as fhandle:
fhandle.seek(0)
size = 2
ftype = 0
while not 0xC0 <= ftype <= 0xCF or ftype in (0xC4, 0xC8, 0xCC):
fhandle.seek(size, 1)
byte = fhandle.read(1)
while ord(byte) == 0xFF:
byte = fhandle.read(1)
ftype = ord(byte)
size = struct.unpack(">H", fhandle.read(2))[0] - 2
fhandle.seek(1, 1)
h, w = struct.unpack(">HH", fhandle.read(4))
return w, h
else:
return None, None
def calculage_img_tokens(
data,
mode: Literal["low", "high", "auto"] = "auto",
base_tokens: int = 85, # openai default - https://openai.com/pricing
):
if mode == "low" or mode == "auto":
return base_tokens
elif mode == "high":
width, height = get_image_dimensions(data=data)
resized_width, resized_height = resize_image_high_res(
width=width, height=height
)
tiles_needed_high_res = calculate_tiles_needed(resized_width, resized_height)
tile_tokens = (base_tokens * 2) * tiles_needed_high_res
total_tokens = base_tokens + tile_tokens
return total_tokens
def create_pretrained_tokenizer(
identifier: str, revision="main", auth_token: Optional[str] = None
):
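For reference, the high-detail cost implemented by the helpers above (now relocated out of utils.py) works out as: resize so the short side is at most 768px, cover the result with 512px tiles, and charge twice the 85 base tokens per tile on top of the base cost. A self-contained sketch of that arithmetic, omitting the 2000px long-side clamp the original also applies:

    import math

    BASE_TOKENS = 85  # flat per-image cost

    def high_detail_tokens(width: int, height: int) -> int:
        # shrink so the shorter side is at most 768px (long-side clamp omitted here)
        if min(width, height) > 768:
            scale = 768 / min(width, height)
            width, height = int(width * scale), int(height * scale)
        # count the 512x512 tiles needed to cover the resized image
        tiles = math.ceil(width / 512) * math.ceil(height / 512)
        # each tile costs twice the base tokens, plus the base cost once
        return BASE_TOKENS + (BASE_TOKENS * 2) * tiles

    print(high_detail_tokens(1024, 1024))  # 2x2 tiles -> 85 + 4 * 170 = 765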
@@ -1615,6 +1501,7 @@ def token_counter(
count_response_tokens: Optional[bool] = False,
tools: Optional[List[ChatCompletionToolParam]] = None,
tool_choice: Optional[ChatCompletionNamedToolChoiceParam] = None,
use_default_image_token_count: Optional[bool] = False,
) -> int:
"""
Count the number of tokens in a given text using a specified model.
@@ -1649,13 +1536,19 @@ def token_counter(
image_url_dict = c["image_url"]
detail = image_url_dict.get("detail", "auto")
url = image_url_dict.get("url")
num_tokens += calculage_img_tokens(
data=url, mode=detail
num_tokens += calculate_img_tokens(
data=url,
mode=detail,
use_default_image_token_count=use_default_image_token_count
or False,
)
elif isinstance(c["image_url"], str):
image_url_str = c["image_url"]
num_tokens += calculage_img_tokens(
data=image_url_str, mode="auto"
num_tokens += calculate_img_tokens(
data=image_url_str,
mode="auto",
use_default_image_token_count=use_default_image_token_count
or False,
)
if message.get("tool_calls"):
is_tool_call = True
@@ -1695,6 +1588,8 @@ def token_counter(
count_response_tokens=count_response_tokens,
tools=tools,
tool_choice=tool_choice,
use_default_image_token_count=use_default_image_token_count
or False,
)
else:
print_verbose(
@@ -1708,6 +1603,8 @@ def token_counter(
count_response_tokens=count_response_tokens,
tools=tools,
tool_choice=tool_choice,
use_default_image_token_count=use_default_image_token_count
or False,
)
else:
num_tokens = len(encoding.encode(text, disallowed_special=())) # type: ignore
@@ -6480,9 +6377,20 @@ def is_prompt_caching_valid_prompt(
OpenAI + Anthropic providers have a minimum token count of 1024 for prompt caching.
"""
if messages is None and tools is None:
try:
if messages is None and tools is None:
return False
if custom_llm_provider is not None and not model.startswith(
custom_llm_provider
):
model = custom_llm_provider + "/" + model
token_count = token_counter(
messages=messages,
tools=tools,
model=model,
use_default_image_token_count=True,
)
return token_count >= 1024
except Exception as e:
verbose_logger.error(f"Error in is_prompt_caching_valid_prompt: {e}")
return False
if custom_llm_provider is not None and not model.startswith(custom_llm_provider):
model = custom_llm_provider + "/" + model
token_count = token_counter(messages=messages, tools=tools, model=model)
return token_count >= 1024
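Net effect for callers: token_counter and openai_token_counter accept a use_default_image_token_count flag, and is_prompt_caching_valid_prompt now passes True and catches exceptions, so checking the 1024-token prompt-caching threshold no longer fails on an image it cannot fetch or parse (the flag name and the return_default_image_dimensions test in the commit list suggest a fixed default size is assumed instead). A usage sketch; the model name and image URL are placeholders:

    import litellm

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/photo.webp", "detail": "high"},
                },
            ],
        }
    ]

    # with the flag set, the image token cost falls back to a default estimate
    # instead of being derived from the downloaded image's real dimensions
    print(litellm.token_counter(model="gpt-4o", messages=messages, use_default_image_token_count=True))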