From f5c1935c18e43aee1889313e10d838165575d166 Mon Sep 17 00:00:00 2001
From: skamenan7
Date: Mon, 21 Jul 2025 14:12:55 -0400
Subject: [PATCH] fix: Resolve Llama4 tool calling 500 errors

This commit addresses issue #2584 by:

- Implementing lazy torch imports in llama4/chat_format.py and datatypes.py
  so that torch-free environments no longer hit ModuleNotFoundError.
- Adding unit tests that verify text-only functionality works without torch
  and that vision features fail gracefully with a clear ImportError.
- Keeping the modules importable and functional for text-based operations,
  which resolves the 500 internal server errors.
---
 .../models/llama/llama4/chat_format.py        | 71 ++++++++++----
 llama_stack/models/llama/llama4/datatypes.py  | 93 ++++++++++-------
 .../models/test_llama4_import_torch_free.py   | 96 +++++++++++++++++
 3 files changed, 216 insertions(+), 44 deletions(-)
 create mode 100644 tests/unit/models/test_llama4_import_torch_free.py

diff --git a/llama_stack/models/llama/llama4/chat_format.py b/llama_stack/models/llama/llama4/chat_format.py
index 96ebd0881..99df3c1ce 100644
--- a/llama_stack/models/llama/llama4/chat_format.py
+++ b/llama_stack/models/llama/llama4/chat_format.py
@@ -4,13 +4,46 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+# NOTE: This file used to import torch unconditionally, which breaks on
+# minimal CPU-only installs and causes HTTP 500s when a Llama-4 model is
+# used for *text-only* tool calling (issue #2584). We now lazy-import the
+# torch/vision dependencies so that importing this module or using the
+# text chat functions does **not** require torch.
+
 import io
 import json
 import uuid
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+def _raise_torch_required():
+    raise ImportError(
+        "Llama-4 vision features require the `torch` and `Pillow` packages. "
+        "Install them to enable image processing."
+    )
+
+# Attempt to import the heavy vision stack; fall back gracefully if unavailable.
+try:
+    import torch  # type: ignore
+    from PIL import Image as PIL_Image  # type: ignore
+
+    from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
+
+    _TORCH_AVAILABLE = True
+except ImportError:  # also covers ModuleNotFoundError; torch, PIL, or their deps missing
+    torch = None  # type: ignore
+    PIL_Image = None  # type: ignore
+    _TORCH_AVAILABLE = False
+
+    # Stub transform classes so the names resolve; they raise if instantiated.
+    class ResizeNormalizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
+
+    class VariableSizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
 
-import torch
-from PIL import Image as PIL_Image
 
 # TODO: either fork these or move them to the common package
 from ..datatypes import (
@@ -27,7 +60,6 @@ from ..datatypes import (
 from ..llama3.tool_utils import ToolUtils
 from .args import VisionArgs
 from .datatypes import LLMInput
-from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
 from .tokenizer import Tokenizer
 
 
@@ -41,20 +73,30 @@ def role_str(role: Role) -> str:
     return role_strs[role]
 
 
-@dataclass
-class TransformedImage:
-    image_tiles: torch.Tensor
-    # is the aspect ratio needed anywhere?
-    aspect_ratio: tuple[int, int]
+# Define TransformedImage in both branches so the module imports without torch.
+if TYPE_CHECKING or _TORCH_AVAILABLE:
+    # Normal case: torch is available.
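+    # (TYPE_CHECKING is in the condition so static type checkers always
+    # analyze the torch-typed variant, even when torch is absent at runtime.)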
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: "torch.Tensor"  # string annotation; only type checkers evaluate it
+        aspect_ratio: tuple[int, int]
+else:
+
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: Any  # placeholder; vision code paths check _TORCH_AVAILABLE first
+        aspect_ratio: tuple[int, int]
 
 
-def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: "PIL_Image.Image", bg: tuple[int, int, int] = (255, 255, 255)) -> "PIL_Image.Image":
     if image.mode == "RGBA":
         image.load()  # for png.split()
-        new_img = PIL_Image.new("RGB", image.size, bg)
+        new_img = PIL_Image.new("RGB", image.size, bg)  # type: ignore
         new_img.paste(image, mask=image.split()[3])  # 3 is the alpha channel
         return new_img
-    return image.convert("RGB")
+    return image.convert("RGB")  # type: ignore
 
 
 class ChatFormat:
@@ -75,6 +117,8 @@ class ChatFormat:
         self.image_transform = None
         self.dynamic_image_transform = None
         if vision_args:
+            if not _TORCH_AVAILABLE:
+                _raise_torch_required()
             self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
             self.image_transform = ResizeNormalizeImageTransform(
                 vision_args.image_size.width, vision_args.image_size.height
@@ -98,6 +142,9 @@
         self,
         transformed_image: TransformedImage,
     ) -> list[int]:
+        if not _TORCH_AVAILABLE:
+            _raise_torch_required()
+
         assert self.vision_args is not None, "The model is not vision-enabled"
 
         image_tensor = transformed_image.image_tiles
@@ -164,7 +211,7 @@
                     added_bos = True
 
                 bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
-                image = PIL_Image.open(bytes_io)
+                image = PIL_Image.open(bytes_io)  # type: ignore
                 image = convert_image_to_rgb(image)
                 image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
diff --git a/llama_stack/models/llama/llama4/datatypes.py b/llama_stack/models/llama/llama4/datatypes.py
index 24d8ae948..cdfa34be9 100644
--- a/llama_stack/models/llama/llama4/datatypes.py
+++ b/llama_stack/models/llama/llama4/datatypes.py
@@ -5,52 +5,81 @@
 # the root directory of this source tree.
 
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
 
-import torch
+# Lazy-import torch so text-only usage does not require it.
+try:
+    import torch  # type: ignore
+
+    _TORCH_AVAILABLE = True
+except ImportError:  # also covers ModuleNotFoundError
+    torch = None  # type: ignore
+    _TORCH_AVAILABLE = False
 
-@dataclass
-class MaskedEmbedding:
-    embedding: torch.Tensor
-    mask: torch.Tensor
+if TYPE_CHECKING or _TORCH_AVAILABLE:
+
+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: "torch.Tensor"
+        mask: "torch.Tensor"
 
-@dataclass
-class LLMInput:
-    """
-    This is the input to the LLM from the "user" -- the user in this case views the
-    Llama4 model holistically and does not care or know about its inner workings (e.g.,
-    whether it has an encoder or if it is early fusion or not.)
+    @dataclass
+    class LLMInput:  # type: ignore
+        """
+        This is the input to the LLM from the "user" -- the user in this case views the
+        Llama4 model holistically and does not care or know about its inner workings (e.g.,
+        whether it has an encoder or if it is early fusion or not.)
 
-    This is distinct from the "TransformerInput" class which is really the Llama4
-    backbone operating on early fused modalities and producing text output
-    """
+        This is distinct from the "TransformerInput" class, which is really the Llama4
+        backbone operating on early fused modalities and producing text output.
+        """
 
-    tokens: torch.Tensor
+        tokens: "torch.Tensor"
+        # images are already pre-processed (resized, tiled, etc.)
+        images: list["torch.Tensor"] | None = None
 
-    # images are already pre-processed (resized, tiled, etc.)
-    images: list[torch.Tensor] | None = None
+    @dataclass
+    class TransformerInput:  # type: ignore
+        """
+        This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
+        are expected to be "embedded" via encoders sitting before this layer in the model.
+        """
+
+        tokens: "torch.Tensor"
+        # tokens_position defines the position of the tokens in each batch:
+        # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
+        # - when it is an int, the start position is the same for all batches
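+        # NOTE: these unions are quoted whole: a bare "torch.Tensor" | int
+        # would evaluate str | type at class creation and raise TypeError.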
+        tokens_position: "torch.Tensor | int"
+        image_embedding: "MaskedEmbedding | None" = None
 
-@dataclass
-class TransformerInput:
-    """
-    This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
-    are expected to be "embedded" via encoders sitting before this layer in the model.
-    """
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: "torch.Tensor"
 
-    tokens: torch.Tensor
+else:
+    # Fallback stubs used when torch is unavailable.
+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: Any
+        mask: Any
 
-    # tokens_position defines the position of the tokens in each batch,
-    # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
-    # - when it is an int, the start position are the same for all batches
-    tokens_position: torch.Tensor | int
-    image_embedding: MaskedEmbedding | None = None
+    @dataclass
+    class LLMInput:  # type: ignore
+        tokens: Any
+        images: Any = None
 
+    @dataclass
+    class TransformerInput:  # type: ignore
+        tokens: Any
+        tokens_position: Any
+        image_embedding: Any = None
 
-@dataclass
-class LLMOutput:
-    logits: torch.Tensor
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: Any
 
 
 TransformerOutput = LLMOutput
diff --git a/tests/unit/models/test_llama4_import_torch_free.py b/tests/unit/models/test_llama4_import_torch_free.py
new file mode 100644
index 000000000..852ac4cfa
--- /dev/null
+++ b/tests/unit/models/test_llama4_import_torch_free.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Test that Llama-4 modules can be imported and used for text-only operations
+even when torch is not available (addresses issue #2584).
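+Vision features must instead fail fast with a clear ImportError.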
+""" + +import builtins +import importlib +import sys + +import pytest + + +def _block_torch(monkeypatch): + """Block torch imports to simulate torch-free environment.""" + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "torch" or name.startswith("torch."): + raise ModuleNotFoundError("No module named 'torch'") + return real_import(name, *args, **kwargs) + + sys.modules.pop("torch", None) # forget any cached import + monkeypatch.setattr(builtins, "__import__", fake_import) + + +def test_llama4_chat_format_imports_without_torch(monkeypatch): + """Test that llama4.chat_format can be imported when torch is unavailable.""" + _block_torch(monkeypatch) + + # This should NOT raise ImportError anymore + chat_format_module = importlib.import_module("llama_stack.models.llama.llama4.chat_format") + assert chat_format_module is not None + + +def test_llama4_text_decoding_works_without_torch(monkeypatch): + """Test that text-only tool calling decoding works without torch.""" + _block_torch(monkeypatch) + + from llama_stack.models.llama.datatypes import StopReason + from llama_stack.models.llama.llama4.chat_format import ChatFormat + from llama_stack.models.llama.llama4.tokenizer import Tokenizer + + # Text-only operations should work fine + formatter = ChatFormat(Tokenizer.get_instance()) + content = '[get_weather(location="SF")]<|eot|>' + msg = formatter.decode_assistant_message_from_content(content, StopReason.end_of_turn) + + # Verify tool calling parsing works + assert msg.tool_calls, "Tool call should be detected" + tc = msg.tool_calls[0] + assert tc.tool_name == "get_weather" + assert tc.arguments == {"location": "SF"} + + +def test_llama4_vision_fails_gracefully_without_torch(monkeypatch): + """Test that vision features raise clear error when torch unavailable.""" + _block_torch(monkeypatch) + + from llama_stack.models.llama.llama4.args import Size, VisionArgs + from llama_stack.models.llama.llama4.chat_format import ChatFormat + from llama_stack.models.llama.llama4.tokenizer import Tokenizer + + # Trying to use vision features should raise helpful error + vision_args = VisionArgs( + image_size=Size(height=448, width=448), + patch_size=Size(height=14, width=14), + dim=512, + n_layers=6, + n_heads=8, + mlp_ratio=4.0, + output_dim=4096, + pixel_shuffle_ratio=2, + ) + with pytest.raises(ImportError, match="vision features require.*torch.*Pillow"): + ChatFormat(Tokenizer.get_instance(), vision_args=vision_args)