mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix: Support WebP image format and avoid token calculation error (#7182)
* fix get_image_dimensions
* attempt without pillow
* add clear type hints
* fix run_async_function_within_sync_function
* fix calculage_img_tokens
* fix is_prompt_caching_valid_prompt
* fix naming
* fix calculate_img_tokens
* fix unused imports
* fix calculate_img_tokens
* test test_is_prompt_caching_enabled_error_handling
* test_is_prompt_caching_enabled_return_default_image_dimensions
* fix openai_token_counter
* fix get_image_dimensions
* test_token_counter_with_image_url_with_detail_high
* test_img_url_token_counter
* fix test utils
* fix testing
* test_is_prompt_caching_enabled
This commit is contained in:
parent
c6d6bda76c
commit
8c7605a164
8 changed files with 336 additions and 143 deletions
|
@ -2,6 +2,9 @@ ROUTER_MAX_FALLBACKS = 5
|
||||||
DEFAULT_BATCH_SIZE = 512
|
DEFAULT_BATCH_SIZE = 512
|
||||||
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
|
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
|
||||||
DEFAULT_MAX_RETRIES = 2
|
DEFAULT_MAX_RETRIES = 2
|
||||||
|
DEFAULT_IMAGE_TOKEN_COUNT = 250
|
||||||
|
DEFAULT_IMAGE_WIDTH = 300
|
||||||
|
DEFAULT_IMAGE_HEIGHT = 300
|
||||||
LITELLM_CHAT_PROVIDERS = [
|
LITELLM_CHAT_PROVIDERS = [
|
||||||
"openai",
|
"openai",
|
||||||
"openai_like",
|
"openai_like",
|
||||||
|
|
|
@ -1,9 +1,18 @@
|
||||||
# What is this?
|
# What is this?
|
||||||
## Helper utilities for token counting
|
## Helper utilities for token counting
|
||||||
from typing import Optional
|
import base64
|
||||||
|
import io
|
||||||
|
import struct
|
||||||
|
from typing import Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import verbose_logger
|
from litellm import verbose_logger
|
||||||
|
from litellm.constants import (
|
||||||
|
DEFAULT_IMAGE_HEIGHT,
|
||||||
|
DEFAULT_IMAGE_TOKEN_COUNT,
|
||||||
|
DEFAULT_IMAGE_WIDTH,
|
||||||
|
)
|
||||||
|
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||||
|
|
||||||
|
|
||||||
def get_modified_max_tokens(
|
def get_modified_max_tokens(
|
||||||
|
@ -81,3 +90,184 @@ def get_modified_max_tokens(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return user_max_tokens
|
return user_max_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def resize_image_high_res(
    width: int,
    height: int,
) -> Tuple[int, int]:
    """
    Compute the dimensions an image is scaled to in "high" detail mode.

    Images that already fit inside a 768x768 square are returned unchanged.
    Otherwise the shorter side is scaled to 768px (aspect ratio preserved); if
    that pushes the longer side above 2000px, the longer side is pinned to
    2000px and the shorter side is scaled down to match.

    Args:
        width (int): Original image width in pixels.
        height (int): Original image height in pixels.

    Returns:
        Tuple[int, int]: The resized (width, height), truncated to whole pixels.
    """
    # Maximum dimensions for high res mode
    max_short_side = 768
    max_long_side = 2000

    # Return early if no resizing is needed
    if width <= max_short_side and height <= max_short_side:
        return width, height

    # Determine the longer and shorter sides
    longer_side = max(width, height)
    shorter_side = min(width, height)

    # A degenerate image (zero or negative short side) has no aspect ratio;
    # clamp each side independently instead of dividing by zero below.
    if shorter_side <= 0:
        return (
            max(0, min(width, max_long_side)),
            max(0, min(height, max_long_side)),
        )

    # Calculate the aspect ratio (always >= 1.0)
    aspect_ratio = longer_side / shorter_side

    # Resize based on the short side being 768px
    resized_short = max_short_side
    resized_long = int(resized_short * aspect_ratio)

    # if the long side exceeds the limit after resizing, adjust both sides accordingly
    if resized_long > max_long_side:
        resized_long = max_long_side
        resized_short = int(resized_long / aspect_ratio)

    # Map short/long back onto width/height: portrait or square keeps the
    # width as the short side, landscape keeps the height as the short side.
    if width <= height:
        return resized_short, resized_long
    return resized_long, resized_short
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_tiles_needed(
    resized_width: int,
    resized_height: int,
    tile_width: int = 512,
    tile_height: int = 512,
) -> int:
    """
    Count the tiles required to cover an image of the given size.

    Args:
        resized_width (int): Image width in pixels.
        resized_height (int): Image height in pixels.
        tile_width (int): Width of one tile in pixels.
        tile_height (int): Height of one tile in pixels.

    Returns:
        int: Number of tiles across multiplied by number of tiles down, where a
        partially covered tile counts as a whole tile.
    """
    # Integer ceiling division: adding (tile - 1) before flooring rounds any
    # partial tile up without going through floats.
    tiles_across = (resized_width + tile_width - 1) // tile_width
    tiles_down = (resized_height + tile_height - 1) // tile_height
    return tiles_across * tiles_down
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_type(image_data: bytes) -> Union[str, None]:
    """take an image (really only the first ~100 bytes max are needed)
    and return 'png' 'gif' 'jpeg' 'webp' 'heic' or None. method added to
    allow deprecation of imghdr in 3.13

    Detection is by leading magic bytes only; nothing past byte 12 is read.
    Note that 'heic' is reported for any data whose bytes 4-8 are ``ftyp``;
    the brand that follows is not inspected.

    Args:
        image_data (bytes): The raw image bytes (or at least their prefix).

    Returns:
        Union[str, None]: The detected format name, or None if unrecognized.
    """
    # PNG: fixed 8-byte signature
    if image_data[0:8] == b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a":
        return "png"

    # GIF: only the two published versions are valid signatures
    if image_data[0:6] in (b"GIF87a", b"GIF89a"):
        return "gif"

    # JPEG: SOI marker followed by the first segment marker byte
    if image_data[0:3] == b"\xff\xd8\xff":
        return "jpeg"

    # HEIC: 'ftyp' box marker right after the 4-byte box size
    if image_data[4:8] == b"ftyp":
        return "heic"

    # WebP: RIFF container whose form type is WEBP
    if image_data[0:4] == b"RIFF" and image_data[8:12] == b"WEBP":
        return "webp"

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_dimensions(
    data: str,
) -> Tuple[int, int]:
    """
    Get the dimensions of an image from a URL or base64 encoded string.

    This is a synchronous function. The image bytes are obtained by issuing a
    GET request for `data`; if that fails, `data` is treated as base64
    (optionally prefixed by a "<header>," data-URI style header). The
    dimensions are then read directly from the PNG, GIF, JPEG or WebP header
    bytes.

    Args:
        data (str): The URL or base64 encoded string of the image.

    Returns:
        Tuple[int, int]: The width and height of the image. Falls back to
        (DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT) when the bytes cannot be
        loaded, the format is not one of the above, or the header is
        truncated/malformed.
    """
    img_data = _load_image_bytes(data)
    if img_data:
        try:
            dimensions = _read_image_dimensions(img_data)
        except (struct.error, TypeError, ValueError):
            # Truncated or malformed header: fall through to the defaults
            # rather than failing the whole token count.
            dimensions = None
        if dimensions is not None:
            return dimensions

    # return sensible default image dimensions if unable to get dimensions
    return DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT


def _load_image_bytes(data: str) -> Optional[bytes]:
    """Return the raw image bytes for a URL or base64 string, or None if neither works."""
    if not isinstance(data, str):
        return None

    try:
        # Try to open as URL
        client = _get_httpx_client()
        response = client.get(data)
        return response.read()
    except Exception as e:
        # Deliberately broad: any failure to fetch means "not a usable URL".
        # Only the exception type is logged because `data` may be a very
        # large base64 payload.
        verbose_logger.debug(
            "get_image_dimensions: GET failed ({}), treating input as base64".format(
                type(e).__name__
            )
        )

    # If not URL, assume it's base64. Strip everything up to the first comma
    # when present; a bare base64 string without a header is decoded as-is.
    head, separator, payload = data.partition(",")
    encoded = payload if separator else head
    try:
        return base64.b64decode(encoded)
    except ValueError:  # binascii.Error is a ValueError subclass
        return None


def _read_image_dimensions(img_data: bytes) -> Optional[Tuple[int, int]]:
    """Parse (width, height) from image header bytes; None for unsupported formats."""
    img_type = get_image_type(img_data)

    if img_type == "png":
        # Two big-endian 32-bit ints at offset 16
        w, h = struct.unpack(">LL", img_data[16:24])
        return w, h

    if img_type == "gif":
        # Two little-endian 16-bit ints right after the 6-byte signature
        w, h = struct.unpack("<HH", img_data[6:10])
        return w, h

    if img_type == "jpeg":
        return _read_jpeg_dimensions(img_data)

    if img_type == "webp":
        # For WebP, the dimensions are stored at different offsets depending on the format
        chunk_type = img_data[12:16]
        if chunk_type == b"VP8X":  # extended format: 24-bit (value - 1) fields
            w = struct.unpack("<I", img_data[24:27] + b"\x00")[0] + 1
            h = struct.unpack("<I", img_data[27:30] + b"\x00")[0] + 1
            return w, h
        if chunk_type == b"VP8 ":  # lossy format: low 14 bits of each 16-bit field
            w = struct.unpack("<H", img_data[26:28])[0] & 0x3FFF
            h = struct.unpack("<H", img_data[28:30])[0] & 0x3FFF
            return w, h
        if chunk_type == b"VP8L":  # lossless format: two packed 14-bit (value - 1) fields
            bits = struct.unpack("<I", img_data[21:25])[0]
            w = (bits & 0x3FFF) + 1
            h = ((bits >> 14) & 0x3FFF) + 1
            return w, h

    return None


def _read_jpeg_dimensions(img_data: bytes) -> Optional[Tuple[int, int]]:
    """Walk JPEG segments until a frame marker (0xC0-0xCF except C4/C8/CC) and read its size."""
    with io.BytesIO(img_data) as fhandle:
        size = 2  # skip the 2-byte start-of-image marker first
        ftype = 0
        while not 0xC0 <= ftype <= 0xCF or ftype in (0xC4, 0xC8, 0xCC):
            if size < 0:
                # A segment length below 2 would seek backwards and could
                # loop forever on malformed data.
                return None
            fhandle.seek(size, 1)
            byte = fhandle.read(1)
            while byte == b"\xff":  # skip 0xFF fill bytes preceding the marker code
                byte = fhandle.read(1)
            if not byte:
                return None  # ran off the end without finding a frame marker
            ftype = ord(byte)
            size = struct.unpack(">H", fhandle.read(2))[0] - 2
        # Skip the 1-byte sample precision; height comes before width
        fhandle.seek(1, 1)
        h, w = struct.unpack(">HH", fhandle.read(4))
        return w, h
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_img_tokens(
    data,
    mode: Literal["low", "high", "auto"] = "auto",
    base_tokens: int = 85,  # openai default - https://openai.com/pricing
    use_default_image_token_count: bool = False,
) -> int:
    """
    Calculate the number of tokens for an image.

    Args:
        data (str): The URL or base64 encoded string of the image.
        mode (Literal["low", "high", "auto"]): The mode to use for calculating the number of tokens.
        base_tokens (int): The base number of tokens for an image.
        use_default_image_token_count (bool): When True, will NOT make a GET request to the image URL and instead return the default image token count.

    Returns:
        int: The number of tokens for the image. "low", "auto" and any
        unrecognized mode cost `base_tokens`; "high" costs `base_tokens` plus
        `2 * base_tokens` per 512x512 tile of the resized image.
    """
    if use_default_image_token_count:
        verbose_logger.debug(
            "Using default image token count: {}".format(DEFAULT_IMAGE_TOKEN_COUNT)
        )
        return DEFAULT_IMAGE_TOKEN_COUNT

    if mode != "high":
        # "low" and "auto" are billed at the base rate. An unexpected value
        # (e.g. an explicit `"detail": null` in the request) is treated the
        # same way instead of silently returning None to the caller.
        return base_tokens

    # Only "high" needs the real image dimensions (this may fetch the image).
    width, height = get_image_dimensions(
        data=data,
    )
    resized_width, resized_height = resize_image_high_res(
        width=width, height=height
    )
    tiles_needed_high_res = calculate_tiles_needed(
        resized_width=resized_width, resized_height=resized_height
    )
    tile_tokens = (base_tokens * 2) * tiles_needed_high_res
    return base_tokens + tile_tokens
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: gpt-4o
|
- model_name: openai/*
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: openai/gpt-4o
|
model: openai/gpt-4o
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
|
|
178
litellm/utils.py
178
litellm/utils.py
|
@ -95,7 +95,10 @@ from litellm.litellm_core_utils.redact_messages import (
|
||||||
)
|
)
|
||||||
from litellm.litellm_core_utils.rules import Rules
|
from litellm.litellm_core_utils.rules import Rules
|
||||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||||
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
|
from litellm.litellm_core_utils.token_counter import (
|
||||||
|
calculate_img_tokens,
|
||||||
|
get_modified_max_tokens,
|
||||||
|
)
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
from litellm.router_utils.get_retry_from_policy import (
|
from litellm.router_utils.get_retry_from_policy import (
|
||||||
get_num_retries_from_retry_policy,
|
get_num_retries_from_retry_policy,
|
||||||
|
@ -1283,6 +1286,7 @@ def openai_token_counter( # noqa: PLR0915
|
||||||
count_response_tokens: Optional[
|
count_response_tokens: Optional[
|
||||||
bool
|
bool
|
||||||
] = False, # Flag passed from litellm.stream_chunk_builder, to indicate counting tokens for LLM Response. We need this because for LLM input we add +3 tokens per message - based on OpenAI's token counter
|
] = False, # Flag passed from litellm.stream_chunk_builder, to indicate counting tokens for LLM Response. We need this because for LLM input we add +3 tokens per message - based on OpenAI's token counter
|
||||||
|
use_default_image_token_count: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Return the number of tokens used by a list of messages.
|
Return the number of tokens used by a list of messages.
|
||||||
|
@ -1341,13 +1345,19 @@ def openai_token_counter( # noqa: PLR0915
|
||||||
image_url_dict = c["image_url"]
|
image_url_dict = c["image_url"]
|
||||||
detail = image_url_dict.get("detail", "auto")
|
detail = image_url_dict.get("detail", "auto")
|
||||||
url = image_url_dict.get("url")
|
url = image_url_dict.get("url")
|
||||||
num_tokens += calculage_img_tokens(
|
num_tokens += calculate_img_tokens(
|
||||||
data=url, mode=detail
|
data=url,
|
||||||
|
mode=detail,
|
||||||
|
use_default_image_token_count=use_default_image_token_count
|
||||||
|
or False,
|
||||||
)
|
)
|
||||||
elif isinstance(c["image_url"], str):
|
elif isinstance(c["image_url"], str):
|
||||||
image_url_str = c["image_url"]
|
image_url_str = c["image_url"]
|
||||||
num_tokens += calculage_img_tokens(
|
num_tokens += calculate_img_tokens(
|
||||||
data=image_url_str, mode="auto"
|
data=image_url_str,
|
||||||
|
mode="auto",
|
||||||
|
use_default_image_token_count=use_default_image_token_count
|
||||||
|
or False,
|
||||||
)
|
)
|
||||||
elif text is not None and count_response_tokens is True:
|
elif text is not None and count_response_tokens is True:
|
||||||
# This is the case where we need to count tokens for a streamed response. We should NOT add +3 tokens per message in this branch
|
# This is the case where we need to count tokens for a streamed response. We should NOT add +3 tokens per message in this branch
|
||||||
|
@ -1375,130 +1385,6 @@ def openai_token_counter( # noqa: PLR0915
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
|
|
||||||
def resize_image_high_res(width, height):
|
|
||||||
# Maximum dimensions for high res mode
|
|
||||||
max_short_side = 768
|
|
||||||
max_long_side = 2000
|
|
||||||
|
|
||||||
# Return early if no resizing is needed
|
|
||||||
if width <= 768 and height <= 768:
|
|
||||||
return width, height
|
|
||||||
|
|
||||||
# Determine the longer and shorter sides
|
|
||||||
longer_side = max(width, height)
|
|
||||||
shorter_side = min(width, height)
|
|
||||||
|
|
||||||
# Calculate the aspect ratio
|
|
||||||
aspect_ratio = longer_side / shorter_side
|
|
||||||
|
|
||||||
# Resize based on the short side being 768px
|
|
||||||
if width <= height: # Portrait or square
|
|
||||||
resized_width = max_short_side
|
|
||||||
resized_height = int(resized_width * aspect_ratio)
|
|
||||||
# if the long side exceeds the limit after resizing, adjust both sides accordingly
|
|
||||||
if resized_height > max_long_side:
|
|
||||||
resized_height = max_long_side
|
|
||||||
resized_width = int(resized_height / aspect_ratio)
|
|
||||||
else: # Landscape
|
|
||||||
resized_height = max_short_side
|
|
||||||
resized_width = int(resized_height * aspect_ratio)
|
|
||||||
# if the long side exceeds the limit after resizing, adjust both sides accordingly
|
|
||||||
if resized_width > max_long_side:
|
|
||||||
resized_width = max_long_side
|
|
||||||
resized_height = int(resized_width / aspect_ratio)
|
|
||||||
|
|
||||||
return resized_width, resized_height
|
|
||||||
|
|
||||||
|
|
||||||
# Test the function with the given example
|
|
||||||
def calculate_tiles_needed(
|
|
||||||
resized_width, resized_height, tile_width=512, tile_height=512
|
|
||||||
):
|
|
||||||
tiles_across = (resized_width + tile_width - 1) // tile_width
|
|
||||||
tiles_down = (resized_height + tile_height - 1) // tile_height
|
|
||||||
total_tiles = tiles_across * tiles_down
|
|
||||||
return total_tiles
|
|
||||||
|
|
||||||
|
|
||||||
def get_image_type(image_data: bytes) -> Union[str, None]:
|
|
||||||
"""take an image (really only the first ~100 bytes max are needed)
|
|
||||||
and return 'png' 'gif' 'jpeg' 'heic' or None. method added to
|
|
||||||
allow deprecation of imghdr in 3.13"""
|
|
||||||
|
|
||||||
if image_data[0:8] == b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a":
|
|
||||||
return "png"
|
|
||||||
|
|
||||||
if image_data[0:4] == b"GIF8" and image_data[5:6] == b"a":
|
|
||||||
return "gif"
|
|
||||||
|
|
||||||
if image_data[0:3] == b"\xff\xd8\xff":
|
|
||||||
return "jpeg"
|
|
||||||
|
|
||||||
if image_data[4:8] == b"ftyp":
|
|
||||||
return "heic"
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_image_dimensions(data):
|
|
||||||
img_data = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Try to open as URL
|
|
||||||
# Try to open as URL
|
|
||||||
client = HTTPHandler(concurrent_limit=1)
|
|
||||||
response = client.get(data)
|
|
||||||
img_data = response.read()
|
|
||||||
except Exception:
|
|
||||||
# If not URL, assume it's base64
|
|
||||||
header, encoded = data.split(",", 1)
|
|
||||||
img_data = base64.b64decode(encoded)
|
|
||||||
|
|
||||||
img_type = get_image_type(img_data)
|
|
||||||
|
|
||||||
if img_type == "png":
|
|
||||||
w, h = struct.unpack(">LL", img_data[16:24])
|
|
||||||
return w, h
|
|
||||||
elif img_type == "gif":
|
|
||||||
w, h = struct.unpack("<HH", img_data[6:10])
|
|
||||||
return w, h
|
|
||||||
elif img_type == "jpeg":
|
|
||||||
with io.BytesIO(img_data) as fhandle:
|
|
||||||
fhandle.seek(0)
|
|
||||||
size = 2
|
|
||||||
ftype = 0
|
|
||||||
while not 0xC0 <= ftype <= 0xCF or ftype in (0xC4, 0xC8, 0xCC):
|
|
||||||
fhandle.seek(size, 1)
|
|
||||||
byte = fhandle.read(1)
|
|
||||||
while ord(byte) == 0xFF:
|
|
||||||
byte = fhandle.read(1)
|
|
||||||
ftype = ord(byte)
|
|
||||||
size = struct.unpack(">H", fhandle.read(2))[0] - 2
|
|
||||||
fhandle.seek(1, 1)
|
|
||||||
h, w = struct.unpack(">HH", fhandle.read(4))
|
|
||||||
return w, h
|
|
||||||
else:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
|
|
||||||
def calculage_img_tokens(
|
|
||||||
data,
|
|
||||||
mode: Literal["low", "high", "auto"] = "auto",
|
|
||||||
base_tokens: int = 85, # openai default - https://openai.com/pricing
|
|
||||||
):
|
|
||||||
if mode == "low" or mode == "auto":
|
|
||||||
return base_tokens
|
|
||||||
elif mode == "high":
|
|
||||||
width, height = get_image_dimensions(data=data)
|
|
||||||
resized_width, resized_height = resize_image_high_res(
|
|
||||||
width=width, height=height
|
|
||||||
)
|
|
||||||
tiles_needed_high_res = calculate_tiles_needed(resized_width, resized_height)
|
|
||||||
tile_tokens = (base_tokens * 2) * tiles_needed_high_res
|
|
||||||
total_tokens = base_tokens + tile_tokens
|
|
||||||
return total_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def create_pretrained_tokenizer(
|
def create_pretrained_tokenizer(
|
||||||
identifier: str, revision="main", auth_token: Optional[str] = None
|
identifier: str, revision="main", auth_token: Optional[str] = None
|
||||||
):
|
):
|
||||||
|
@ -1615,6 +1501,7 @@ def token_counter(
|
||||||
count_response_tokens: Optional[bool] = False,
|
count_response_tokens: Optional[bool] = False,
|
||||||
tools: Optional[List[ChatCompletionToolParam]] = None,
|
tools: Optional[List[ChatCompletionToolParam]] = None,
|
||||||
tool_choice: Optional[ChatCompletionNamedToolChoiceParam] = None,
|
tool_choice: Optional[ChatCompletionNamedToolChoiceParam] = None,
|
||||||
|
use_default_image_token_count: Optional[bool] = False,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Count the number of tokens in a given text using a specified model.
|
Count the number of tokens in a given text using a specified model.
|
||||||
|
@ -1649,13 +1536,19 @@ def token_counter(
|
||||||
image_url_dict = c["image_url"]
|
image_url_dict = c["image_url"]
|
||||||
detail = image_url_dict.get("detail", "auto")
|
detail = image_url_dict.get("detail", "auto")
|
||||||
url = image_url_dict.get("url")
|
url = image_url_dict.get("url")
|
||||||
num_tokens += calculage_img_tokens(
|
num_tokens += calculate_img_tokens(
|
||||||
data=url, mode=detail
|
data=url,
|
||||||
|
mode=detail,
|
||||||
|
use_default_image_token_count=use_default_image_token_count
|
||||||
|
or False,
|
||||||
)
|
)
|
||||||
elif isinstance(c["image_url"], str):
|
elif isinstance(c["image_url"], str):
|
||||||
image_url_str = c["image_url"]
|
image_url_str = c["image_url"]
|
||||||
num_tokens += calculage_img_tokens(
|
num_tokens += calculate_img_tokens(
|
||||||
data=image_url_str, mode="auto"
|
data=image_url_str,
|
||||||
|
mode="auto",
|
||||||
|
use_default_image_token_count=use_default_image_token_count
|
||||||
|
or False,
|
||||||
)
|
)
|
||||||
if message.get("tool_calls"):
|
if message.get("tool_calls"):
|
||||||
is_tool_call = True
|
is_tool_call = True
|
||||||
|
@ -1695,6 +1588,8 @@ def token_counter(
|
||||||
count_response_tokens=count_response_tokens,
|
count_response_tokens=count_response_tokens,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
|
use_default_image_token_count=use_default_image_token_count
|
||||||
|
or False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print_verbose(
|
print_verbose(
|
||||||
|
@ -1708,6 +1603,8 @@ def token_counter(
|
||||||
count_response_tokens=count_response_tokens,
|
count_response_tokens=count_response_tokens,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
|
use_default_image_token_count=use_default_image_token_count
|
||||||
|
or False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
num_tokens = len(encoding.encode(text, disallowed_special=())) # type: ignore
|
num_tokens = len(encoding.encode(text, disallowed_special=())) # type: ignore
|
||||||
|
@ -6480,9 +6377,20 @@ def is_prompt_caching_valid_prompt(
|
||||||
|
|
||||||
OpenAI + Anthropic providers have a minimum token count of 1024 for prompt caching.
|
OpenAI + Anthropic providers have a minimum token count of 1024 for prompt caching.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
if messages is None and tools is None:
|
if messages is None and tools is None:
|
||||||
return False
|
return False
|
||||||
if custom_llm_provider is not None and not model.startswith(custom_llm_provider):
|
if custom_llm_provider is not None and not model.startswith(
|
||||||
|
custom_llm_provider
|
||||||
|
):
|
||||||
model = custom_llm_provider + "/" + model
|
model = custom_llm_provider + "/" + model
|
||||||
token_count = token_counter(messages=messages, tools=tools, model=model)
|
token_count = token_counter(
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
model=model,
|
||||||
|
use_default_image_token_count=True,
|
||||||
|
)
|
||||||
return token_count >= 1024
|
return token_count >= 1024
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.error(f"Error in is_prompt_caching_valid_prompt: {e}")
|
||||||
|
return False
|
||||||
|
|
|
@ -16,6 +16,7 @@ from respx import MockRouter
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import Choices, Message, ModelResponse
|
from litellm import Choices, Message, ModelResponse
|
||||||
from base_llm_unit_tests import BaseLLMChatTest
|
from base_llm_unit_tests import BaseLLMChatTest
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
def test_openai_prediction_param():
|
def test_openai_prediction_param():
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
def calculage_img_tokens(
|
def calculate_img_tokens(
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
mode: Literal["low", "high", "auto"] = "auto",
|
mode: Literal["low", "high", "auto"] = "auto",
|
||||||
|
|
|
@ -370,7 +370,7 @@ def test_gpt_4o_token_counter():
|
||||||
)
|
)
|
||||||
def test_img_url_token_counter(img_url):
|
def test_img_url_token_counter(img_url):
|
||||||
|
|
||||||
from litellm.utils import get_image_dimensions
|
from litellm.litellm_core_utils.token_counter import get_image_dimensions
|
||||||
|
|
||||||
width, height = get_image_dimensions(data=img_url)
|
width, height = get_image_dimensions(data=img_url)
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,8 @@ from litellm.utils import (
|
||||||
trim_messages,
|
trim_messages,
|
||||||
validate_environment,
|
validate_environment,
|
||||||
)
|
)
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
# Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'
|
# Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'
|
||||||
|
|
||||||
|
@ -1147,3 +1149,92 @@ def test_get_end_user_id_for_cost_tracking_prometheus_only(
|
||||||
)
|
)
|
||||||
== expected_end_user_id
|
== expected_end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_prompt_caching_enabled_error_handling():
    """
    Assert that `is_prompt_caching_valid_prompt` safely handles errors in `token_counter`.
    """
    with patch(
        "litellm.utils.token_counter",
        side_effect=Exception(
            "Mocked error, This should not raise an error. Instead is_prompt_caching_valid_prompt should return False."
        ),
    ) as mock_token_counter:
        result = litellm.utils.is_prompt_caching_valid_prompt(
            messages=[{"role": "user", "content": "test"}],
            tools=None,
            custom_llm_provider="anthropic",
            model="anthropic/claude-3-5-sonnet-20240620",
        )

    # Make sure the False really comes from the error path, not an early return
    mock_token_counter.assert_called_once()
    assert result is False  # Should return False when an error occurs
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_prompt_caching_enabled_return_default_image_dimensions():
    """
    Assert that `is_prompt_caching_valid_prompt` calls token_counter with use_default_image_token_count=True
    when processing messages containing images

    IMPORTANT: Ensures Get token counter does not make a GET request to the image url
    """
    # Give the mock an int return value: a bare MagicMock cannot be compared
    # with `>= 1024`, which would push the function under test into its error
    # handler instead of its normal path.
    with patch("litellm.utils.token_counter", return_value=1024) as mock_token_counter:
        litellm.utils.is_prompt_caching_valid_prompt(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What is in this image?"},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "https://www.gstatic.com/webp/gallery/1.webp",
                                "detail": "high",
                            },
                        },
                    ],
                }
            ],
            tools=None,
            custom_llm_provider="openai",
            model="gpt-4o-mini",
        )

        # Assert token_counter was called with use_default_image_token_count=True
        mock_token_counter.assert_called_once()
        args_to_mock_token_counter = mock_token_counter.call_args[1]
        print("args_to_mock", args_to_mock_token_counter)
        assert args_to_mock_token_counter["use_default_image_token_count"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_counter_with_image_url_with_detail_high():
    """
    Assert that token_counter does not make a GET request to the image url when `use_default_image_token_count=True`

    PROD TEST this is important - Can impact latency very badly
    """
    from litellm.constants import DEFAULT_IMAGE_TOKEN_COUNT
    from litellm._logging import verbose_logger
    import logging

    # Remember the level so DEBUG logging does not leak into later tests
    previous_level = verbose_logger.level
    verbose_logger.setLevel(logging.DEBUG)
    try:
        # get_image_dimensions is the only place the image would be fetched;
        # patch it so the "no GET request" claim is actually asserted.
        with patch(
            "litellm.litellm_core_utils.token_counter.get_image_dimensions"
        ) as mock_get_image_dimensions:
            _tokens = litellm.utils.token_counter(
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": "https://www.gstatic.com/webp/gallery/1.webp",
                                    "detail": "high",
                                },
                            },
                        ],
                    }
                ],
                model="gpt-4o-mini",
                use_default_image_token_count=True,
            )
        mock_get_image_dimensions.assert_not_called()
    finally:
        verbose_logger.setLevel(previous_level)

    print("tokens", _tokens)
    assert _tokens == DEFAULT_IMAGE_TOKEN_COUNT + 7
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue