fix: Resolve Llama4 tool calling 500 errors

This commit addresses issue #2584 by:
- Implementing lazy torch imports in llama4/chat_format.py and datatypes.py to prevent ModuleNotFoundError in torch-free environments.
- Adding comprehensive unit tests to verify that text-only functionality works without torch and that vision features fail gracefully (a rough sketch of such a test follows below).
- Ensuring the module remains importable and functional for text-based operations, thus resolving the 500 internal server errors.
Author: skamenan7, 2025-07-21 14:12:55 -04:00 (committed by Sumanth Kamenani)
Parent: 3d43e143d2
Commit: f5c1935c18
3 changed files with 197 additions and 44 deletions
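
As a rough illustration of the torch-free check the commit message describes (a hypothetical sketch, not the test file from this commit; the module path `llama4.chat_format` and the use of pytest's monkeypatch fixture are assumptions):

import importlib
import sys

import pytest


def test_chat_format_imports_without_torch(monkeypatch):
    # Hypothetical sketch; the module path is an assumption.
    # Putting None in sys.modules makes `import torch` raise ImportError,
    # simulating a minimal CPU-only install.
    monkeypatch.setitem(sys.modules, "torch", None)
    monkeypatch.setitem(sys.modules, "PIL", None)
    sys.modules.pop("llama4.chat_format", None)  # drop any cached copy

    chat_format = importlib.import_module("llama4.chat_format")  # must not raise

    assert chat_format._TORCH_AVAILABLE is False
    # The dummy vision transforms should fail with a clear ImportError
    # up front rather than surfacing later as an HTTP 500.
    with pytest.raises(ImportError):
        chat_format.VariableSizeImageTransform()

(Note that the re-imported module stays cached in its torch-free state, so a real suite would run this in a subprocess or restore sys.modules afterwards.)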

llama4/chat_format.py

@@ -4,13 +4,46 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+# NOTE: This file used to import torch unconditionally, which breaks on
+# minimal CPU-only installs and causes HTTP 500s when a Llama-4 model is
+# used for *text-only* tool calling (issue #2584). We now lazy-import
+# torch/vision deps so that simply importing the module or using text chat
+# functions does **not** require torch.
 import io
 import json
 import uuid
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any

+# Attempt to import the heavy vision stack; fall back gracefully if unavailable.
+try:
+    import torch  # type: ignore
+    from PIL import Image as PIL_Image  # type: ignore
+
+    from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform  # noqa: E401
+
+    _TORCH_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):  # torch or PIL or derived deps missing
+    torch = None  # type: ignore
+    PIL_Image = None  # type: ignore
+    _TORCH_AVAILABLE = False
+
+    def _raise_torch_required():
+        raise ImportError(
+            "Llama-4 vision features require the `torch` and `Pillow` packages. "
+            "Install them to enable image processing."
+        )
+
+    # Create dummy transform classes that raise if instantiated
+    class ResizeNormalizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
+
+    class VariableSizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
+
-import torch
-from PIL import Image as PIL_Image

 # TODO: either fork these or move them to the common package
 from ..datatypes import (
@@ -27,7 +60,6 @@ from ..datatypes import (
 from ..llama3.tool_utils import ToolUtils
 from .args import VisionArgs
 from .datatypes import LLMInput
-from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
 from .tokenizer import Tokenizer
@@ -41,20 +73,28 @@ def role_str(role: Role) -> str:
     return role_strs[role]

-@dataclass
-class TransformedImage:
-    image_tiles: torch.Tensor
-    # is the aspect ratio needed anywhere?
-    aspect_ratio: tuple[int, int]
+# Define TransformedImage so that the module is importable without torch.
+if TYPE_CHECKING or _TORCH_AVAILABLE:
+    # Normal case - torch available.
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: "torch.Tensor"  # quotes to avoid mypy when torch is None
+        aspect_ratio: tuple[int, int]
+else:
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: Any  # placeholder; actual methods will never be called
+        aspect_ratio: tuple[int, int]

-def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: "PIL_Image.Image", bg: tuple[int, int, int] = (255, 255, 255)) -> "PIL_Image.Image":
     if image.mode == "RGBA":
         image.load()  # for png.split()
-        new_img = PIL_Image.new("RGB", image.size, bg)
+        new_img = PIL_Image.new("RGB", image.size, bg)  # type: ignore
         new_img.paste(image, mask=image.split()[3])  # 3 is the alpha channel
         return new_img
-    return image.convert("RGB")
+    return image.convert("RGB")  # type: ignore


 class ChatFormat:
@@ -75,6 +115,8 @@ class ChatFormat:
         self.image_transform = None
         self.dynamic_image_transform = None
         if vision_args:
+            if not _TORCH_AVAILABLE:
+                _raise_torch_required()
             self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
             self.image_transform = ResizeNormalizeImageTransform(
                 vision_args.image_size.width, vision_args.image_size.height
@@ -98,6 +140,9 @@ class ChatFormat:
         self,
         transformed_image: TransformedImage,
     ) -> list[int]:
+        if not _TORCH_AVAILABLE:
+            _raise_torch_required()
+
         assert self.vision_args is not None, "The model is not vision-enabled"

         image_tensor = transformed_image.image_tiles
@@ -164,7 +209,7 @@ class ChatFormat:
                     added_bos = True

                 bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
-                image = PIL_Image.open(bytes_io)
+                image = PIL_Image.open(bytes_io)  # type: ignore
                 image = convert_image_to_rgb(image)
                 image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
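
The net effect of the guards above, illustrated as a minimal sketch (the module path, the ChatFormat constructor shape, and `Tokenizer.get_instance()` are assumptions from the surrounding package, not shown in this diff):

from llama4.chat_format import ChatFormat
from llama4.tokenizer import Tokenizer

# Text-only use: vision_args is None, so the torch guard is never hit
# and no torch import is attempted.
fmt = ChatFormat(Tokenizer.get_instance(), vision_args=None)

# On a torch-free install, enabling vision fails fast in __init__ with
#   ImportError: Llama-4 vision features require the `torch` and `Pillow` packages. ...
# instead of a ModuleNotFoundError surfacing later as an HTTP 500.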