fix: Resolve Llama4 tool calling 500 errors

This commit addresses issue #2584 by:
- Implementing lazy torch imports in llama4/chat_format.py and datatypes.py to prevent ModuleNotFoundError in torch-free environments.
- Adding comprehensive unit tests to verify that text-only functionality works without torch and that vision features fail gracefully.
- Ensuring the module remains importable and functional for text-based operations, thus resolving the 500 internal server errors.
skamenan7 2025-07-21 14:12:55 -04:00 committed by Sumanth Kamenani
parent 3d43e143d2
commit f5c1935c18
3 changed files with 197 additions and 44 deletions
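The core of the fix is a guarded import: torch is imported inside a try/except, a module-level flag records whether it is available, and vision-only code paths call a small helper that raises a descriptive ImportError. The sketch below is illustrative only (names such as _require_torch are placeholders, not the identifiers used in the diff that follows):

try:
    import torch  # heavy, optional dependency
    _TORCH_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    torch = None
    _TORCH_AVAILABLE = False

def _require_torch() -> None:
    # Only vision code paths call this, so text-only tool calling never touches torch.
    if not _TORCH_AVAILABLE:
        raise ImportError("This feature requires the optional `torch` package.")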

llama_stack/models/llama/llama4/chat_format.py

@@ -4,13 +4,46 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+# NOTE: This file used to import torch unconditionally, which breaks on
+# minimal CPU-only installs and causes HTTP 500s when a Llama-4 model is
+# used for *text-only* tool calling (issue #2584). We now lazy-import
+# torch/vision deps so that simply importing the module or using text chat
+# functions does **not** require torch.
 import io
 import json
 import uuid
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
 
-import torch
-from PIL import Image as PIL_Image
+# Attempt to import the heavy vision stack; fall back gracefully if unavailable.
+try:
+    import torch  # type: ignore
+    from PIL import Image as PIL_Image  # type: ignore
+    from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform  # noqa: E401
+
+    _TORCH_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):  # torch or PIL or derived deps missing
+    torch = None  # type: ignore
+    PIL_Image = None  # type: ignore
+    _TORCH_AVAILABLE = False
+
+    def _raise_torch_required():
+        raise ImportError(
+            "Llama-4 vision features require the `torch` and `Pillow` packages. "
+            "Install them to enable image processing."
+        )
+
+    # Create dummy transform classes that raise if instantiated
+    class ResizeNormalizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
+
+    class VariableSizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
 
 # TODO: either fork these or move them to the common package
 from ..datatypes import (
@@ -27,7 +60,6 @@ from ..datatypes import (
 from ..llama3.tool_utils import ToolUtils
 from .args import VisionArgs
 from .datatypes import LLMInput
-from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
 from .tokenizer import Tokenizer
@@ -41,20 +73,28 @@ def role_str(role: Role) -> str:
     return role_strs[role]
 
 
-@dataclass
-class TransformedImage:
-    image_tiles: torch.Tensor
-    # is the aspect ratio needed anywhere?
-    aspect_ratio: tuple[int, int]
+# Define TransformedImage so that the module is importable without torch.
+if TYPE_CHECKING or _TORCH_AVAILABLE:
+    # Normal case - torch available.
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: "torch.Tensor"  # quotes to avoid mypy when torch is None
+        aspect_ratio: tuple[int, int]
+else:
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: Any  # placeholder; actual methods will never be called
+        aspect_ratio: tuple[int, int]
 
 
-def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: "PIL_Image.Image", bg: tuple[int, int, int] = (255, 255, 255)) -> "PIL_Image.Image":
     if image.mode == "RGBA":
         image.load()  # for png.split()
-        new_img = PIL_Image.new("RGB", image.size, bg)
+        new_img = PIL_Image.new("RGB", image.size, bg)  # type: ignore
         new_img.paste(image, mask=image.split()[3])  # 3 is the alpha channel
         return new_img
-    return image.convert("RGB")
+    return image.convert("RGB")  # type: ignore
 
 
 class ChatFormat:
@@ -75,6 +115,8 @@ class ChatFormat:
         self.image_transform = None
         self.dynamic_image_transform = None
         if vision_args:
+            if not _TORCH_AVAILABLE:
+                _raise_torch_required()
             self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
             self.image_transform = ResizeNormalizeImageTransform(
                 vision_args.image_size.width, vision_args.image_size.height
@@ -98,6 +140,9 @@ class ChatFormat:
         self,
         transformed_image: TransformedImage,
     ) -> list[int]:
+        if not _TORCH_AVAILABLE:
+            _raise_torch_required()
+
         assert self.vision_args is not None, "The model is not vision-enabled"
 
         image_tensor = transformed_image.image_tiles
@@ -164,7 +209,7 @@ class ChatFormat:
                 added_bos = True
 
             bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
-            image = PIL_Image.open(bytes_io)
+            image = PIL_Image.open(bytes_io)  # type: ignore
             image = convert_image_to_rgb(image)
             image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)

llama_stack/models/llama/llama4/datatypes.py

@@ -5,52 +5,79 @@
 # the root directory of this source tree.
 
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
 
-import torch
+# Lazy import torch to avoid breaking text-only usage
+try:
+    import torch  # type: ignore
+
+    _TORCH_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):
+    torch = None  # type: ignore
+    _TORCH_AVAILABLE = False
 
 
-@dataclass
-class MaskedEmbedding:
-    embedding: torch.Tensor
-    mask: torch.Tensor
+if TYPE_CHECKING or _TORCH_AVAILABLE:
 
+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: "torch.Tensor"
+        mask: "torch.Tensor"
 
-@dataclass
-class LLMInput:
-    """
-    This is the input to the LLM from the "user" -- the user in this case views the
-    Llama4 model holistically and does not care or know about its inner workings (e.g.,
-    whether it has an encoder or if it is early fusion or not.)
+    @dataclass
+    class LLMInput:  # type: ignore
+        """
+        This is the input to the LLM from the "user" -- the user in this case views the
+        Llama4 model holistically and does not care or know about its inner workings (e.g.,
+        whether it has an encoder or if it is early fusion or not.)
 
-    This is distinct from the "TransformerInput" class which is really the Llama4
-    backbone operating on early fused modalities and producing text output
-    """
+        This is distinct from the "TransformerInput" class which is really the Llama4
+        backbone operating on early fused modalities and producing text output
+        """
 
-    tokens: torch.Tensor
+        tokens: "torch.Tensor"
 
-    # images are already pre-processed (resized, tiled, etc.)
-    images: list[torch.Tensor] | None = None
+        # images are already pre-processed (resized, tiled, etc.)
+        images: list["torch.Tensor"] | None = None
 
-
-@dataclass
-class TransformerInput:
-    """
-    This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
-    are expected to be "embedded" via encoders sitting before this layer in the model.
-    """
+    @dataclass
+    class TransformerInput:  # type: ignore
+        """
+        This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
+        are expected to be "embedded" via encoders sitting before this layer in the model.
+        """
 
-    tokens: torch.Tensor
+        tokens: "torch.Tensor"
 
-    # tokens_position defines the position of the tokens in each batch,
-    # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
-    # - when it is an int, the start position are the same for all batches
-    tokens_position: torch.Tensor | int
-    image_embedding: MaskedEmbedding | None = None
+        # tokens_position defines the position of the tokens in each batch,
+        # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
+        # - when it is an int, the start position are the same for all batches
+        tokens_position: "torch.Tensor | int"
+        image_embedding: "MaskedEmbedding | None" = None
 
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: "torch.Tensor"
 
-@dataclass
-class LLMOutput:
-    logits: torch.Tensor
+else:
+    # Fallback stubs when torch unavailable
+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: Any
+        mask: Any
 
+    @dataclass
+    class LLMInput:  # type: ignore
+        tokens: Any
+        images: Any = None
+
+    @dataclass
+    class TransformerInput:  # type: ignore
+        tokens: Any
+        tokens_position: Any
+        image_embedding: Any = None
+
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: Any
 
 TransformerOutput = LLMOutput

New unit test module (path not shown in this view): Llama-4 import and text-only usage without torch

@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Test that Llama-4 modules can be imported and used for text-only operations
even when torch is not available (addresses issue #2584).
"""

import builtins
import importlib
import sys

import pytest


def _block_torch(monkeypatch):
    """Block torch imports to simulate torch-free environment."""
    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        if name == "torch" or name.startswith("torch."):
            raise ModuleNotFoundError("No module named 'torch'")
        return real_import(name, *args, **kwargs)

    sys.modules.pop("torch", None)  # forget any cached import
    monkeypatch.setattr(builtins, "__import__", fake_import)


def test_llama4_chat_format_imports_without_torch(monkeypatch):
    """Test that llama4.chat_format can be imported when torch is unavailable."""
    _block_torch(monkeypatch)

    # This should NOT raise ImportError anymore
    chat_format_module = importlib.import_module("llama_stack.models.llama.llama4.chat_format")
    assert chat_format_module is not None


def test_llama4_text_decoding_works_without_torch(monkeypatch):
    """Test that text-only tool calling decoding works without torch."""
    _block_torch(monkeypatch)

    from llama_stack.models.llama.datatypes import StopReason
    from llama_stack.models.llama.llama4.chat_format import ChatFormat
    from llama_stack.models.llama.llama4.tokenizer import Tokenizer

    # Text-only operations should work fine
    formatter = ChatFormat(Tokenizer.get_instance())
    content = '[get_weather(location="SF")]<|eot|>'
    msg = formatter.decode_assistant_message_from_content(content, StopReason.end_of_turn)

    # Verify tool calling parsing works
    assert msg.tool_calls, "Tool call should be detected"
    tc = msg.tool_calls[0]
    assert tc.tool_name == "get_weather"
    assert tc.arguments == {"location": "SF"}


def test_llama4_vision_fails_gracefully_without_torch(monkeypatch):
    """Test that vision features raise clear error when torch unavailable."""
    _block_torch(monkeypatch)

    from llama_stack.models.llama.llama4.args import Size, VisionArgs
    from llama_stack.models.llama.llama4.chat_format import ChatFormat
    from llama_stack.models.llama.llama4.tokenizer import Tokenizer

    # Trying to use vision features should raise helpful error
    vision_args = VisionArgs(
        image_size=Size(height=448, width=448),
        patch_size=Size(height=14, width=14),
        dim=512,
        n_layers=6,
        n_heads=8,
        mlp_ratio=4.0,
        output_dim=4096,
        pixel_shuffle_ratio=2,
    )
    with pytest.raises(ImportError, match="vision features require.*torch.*Pillow"):
        ChatFormat(Tokenizer.get_instance(), vision_args=vision_args)
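For a quick manual check outside pytest, the same import-blocking idea can be run as a stand-alone snippet. This is an illustrative sketch, not part of the commit; the only identifier taken from the commit is the module path exercised by the tests above:

import builtins
import importlib
import sys

real_import = builtins.__import__

def fake_import(name, *args, **kwargs):
    # Pretend torch is not installed for any code that tries to import it.
    if name == "torch" or name.startswith("torch."):
        raise ModuleNotFoundError("No module named 'torch'")
    return real_import(name, *args, **kwargs)

sys.modules.pop("torch", None)  # drop any cached torch import
builtins.__import__ = fake_import
try:
    mod = importlib.import_module("llama_stack.models.llama.llama4.chat_format")
    print("torch-free import succeeded:", mod.__name__)
finally:
    builtins.__import__ = real_import  # always restore the real import hook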