Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-27 14:38:49 +00:00)
fix: Resolve Llama4 tool calling 500 errors
This commit addresses issue #2584 by:
- Implementing lazy torch imports in llama4/chat_format.py and datatypes.py to prevent ModuleNotFoundError in torch-free environments.
- Adding comprehensive unit tests to verify that text-only functionality works without torch and that vision features fail gracefully.
- Ensuring the module remains importable and functional for text-based operations, thus resolving the 500 internal server errors.
This commit is contained in:
parent 3d43e143d2
commit f5c1935c18

3 changed files with 197 additions and 44 deletions
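The fix follows a standard guarded-import pattern: attempt the heavy import once at module load, record whether it succeeded, and defer the failure to the first call that actually needs the dependency. A minimal sketch of the pattern, using the same `_TORCH_AVAILABLE` flag name the diff introduces (the `embed_image` function is hypothetical, for illustration only):

try:
    import torch  # heavy optional dependency

    _TORCH_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    torch = None  # the module itself still imports cleanly
    _TORCH_AVAILABLE = False


def embed_image(data):
    # Hypothetical vision entry point: fail loudly only when actually used.
    if not _TORCH_AVAILABLE:
        raise ImportError("this feature requires `torch`; install it to enable image processing")
    return torch.as_tensor(data)  # only reached when torch is importable

Text-only code paths never enter `embed_image`, so a CPU-only install works end to end.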
llama_stack/models/llama/llama4/chat_format.py

@@ -4,13 +4,46 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+# NOTE: This file used to import torch unconditionally, which breaks on
+# minimal CPU-only installs and causes HTTP 500s when a Llama-4 model is
+# used for *text-only* tool calling (issue #2584). We now lazy-import
+# torch/vision deps so that simply importing the module or using text chat
+# functions does **not** require torch.
+
 import io
 import json
 import uuid
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+# Attempt to import the heavy vision stack; fall back gracefully if unavailable.
+try:
+    import torch  # type: ignore
+    from PIL import Image as PIL_Image  # type: ignore
+
+    from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform  # noqa: E401
+
+    _TORCH_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):  # torch or PIL or derived deps missing
+    torch = None  # type: ignore
+    PIL_Image = None  # type: ignore
+    _TORCH_AVAILABLE = False
+
+    def _raise_torch_required():
+        raise ImportError(
+            "Llama-4 vision features require the `torch` and `Pillow` packages. "
+            "Install them to enable image processing."
+        )
+
+    # Create dummy transform classes that raise if instantiated
+    class ResizeNormalizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
+
+    class VariableSizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
 
-import torch
-from PIL import Image as PIL_Image
 
 # TODO: either fork these or move them to the common package
 from ..datatypes import (

@@ -27,7 +60,6 @@ from ..datatypes import (
 from ..llama3.tool_utils import ToolUtils
 from .args import VisionArgs
 from .datatypes import LLMInput
-from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
 from .tokenizer import Tokenizer
 
 

@@ -41,20 +73,28 @@ def role_str(role: Role) -> str:
     return role_strs[role]
 
 
-@dataclass
-class TransformedImage:
-    image_tiles: torch.Tensor
-    # is the aspect ratio needed anywhere?
-    aspect_ratio: tuple[int, int]
+# Define TransformedImage so that the module is importable without torch.
+if TYPE_CHECKING or _TORCH_AVAILABLE:
+    # Normal case - torch available.
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: "torch.Tensor"  # quotes to avoid mypy when torch is None
+        aspect_ratio: tuple[int, int]
+else:
+
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: Any  # placeholder; actual methods will never be called
+        aspect_ratio: tuple[int, int]
 
 
-def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: "PIL_Image.Image", bg: tuple[int, int, int] = (255, 255, 255)) -> "PIL_Image.Image":
     if image.mode == "RGBA":
         image.load()  # for png.split()
-        new_img = PIL_Image.new("RGB", image.size, bg)
+        new_img = PIL_Image.new("RGB", image.size, bg)  # type: ignore
         new_img.paste(image, mask=image.split()[3])  # 3 is the alpha channel
         return new_img
-    return image.convert("RGB")
+    return image.convert("RGB")  # type: ignore
 
 
 class ChatFormat:

@@ -75,6 +115,8 @@ class ChatFormat:
         self.image_transform = None
         self.dynamic_image_transform = None
         if vision_args:
+            if not _TORCH_AVAILABLE:
+                _raise_torch_required()
             self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
             self.image_transform = ResizeNormalizeImageTransform(
                 vision_args.image_size.width, vision_args.image_size.height

@@ -98,6 +140,9 @@ class ChatFormat:
         self,
         transformed_image: TransformedImage,
     ) -> list[int]:
+        if not _TORCH_AVAILABLE:
+            _raise_torch_required()
+
         assert self.vision_args is not None, "The model is not vision-enabled"
 
         image_tensor = transformed_image.image_tiles

@@ -164,7 +209,7 @@ class ChatFormat:
                     added_bos = True
 
                 bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
-                image = PIL_Image.open(bytes_io)
+                image = PIL_Image.open(bytes_io)  # type: ignore
                 image = convert_image_to_rgb(image)
                 image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
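With these hunks applied, importing the module on a torch-free install succeeds, and only the vision paths fail, at construction time rather than deep inside a request handler. A sketch of the observable behavior (assuming a CPU-only environment without torch or Pillow):

from llama_stack.models.llama.llama4.chat_format import ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer

# Text-only construction and tool-call decoding work without torch.
formatter = ChatFormat(Tokenizer.get_instance())

# Passing vision_args would instead raise ImportError immediately via
# _raise_torch_required(), which is what the unit tests below assert.

Failing in the constructor turns a late 500 into an explicit, actionable ImportError.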
llama_stack/models/llama/llama4/datatypes.py

@@ -5,52 +5,79 @@
 # the root directory of this source tree.
 
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
 
-import torch
-
-
-@dataclass
-class MaskedEmbedding:
-    embedding: torch.Tensor
-    mask: torch.Tensor
-
-
-@dataclass
-class LLMInput:
-    """
-    This is the input to the LLM from the "user" -- the user in this case views the
-    Llama4 model holistically and does not care or know about its inner workings (e.g.,
-    whether it has an encoder or if it is early fusion or not.)
-
-    This is distinct from the "TransformerInput" class which is really the Llama4
-    backbone operating on early fused modalities and producing text output
-    """
-
-    tokens: torch.Tensor
-
-    # images are already pre-processed (resized, tiled, etc.)
-    images: list[torch.Tensor] | None = None
-
-
-@dataclass
-class TransformerInput:
-    """
-    This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
-    are expected to be "embedded" via encoders sitting before this layer in the model.
-    """
-
-    tokens: torch.Tensor
-
-    # tokens_position defines the position of the tokens in each batch,
-    # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
-    # - when it is an int, the start position are the same for all batches
-    tokens_position: torch.Tensor | int
-    image_embedding: MaskedEmbedding | None = None
-
-
-@dataclass
-class LLMOutput:
-    logits: torch.Tensor
+# Lazy import torch to avoid breaking text-only usage
+try:
+    import torch  # type: ignore
+
+    _TORCH_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):
+    torch = None  # type: ignore
+    _TORCH_AVAILABLE = False
+
+
+if TYPE_CHECKING or _TORCH_AVAILABLE:
+
+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: "torch.Tensor"
+        mask: "torch.Tensor"
+
+    @dataclass
+    class LLMInput:  # type: ignore
+        """
+        This is the input to the LLM from the "user" -- the user in this case views the
+        Llama4 model holistically and does not care or know about its inner workings (e.g.,
+        whether it has an encoder or if it is early fusion or not.)
+
+        This is distinct from the "TransformerInput" class which is really the Llama4
+        backbone operating on early fused modalities and producing text output
+        """
+
+        tokens: "torch.Tensor"
+        # images are already pre-processed (resized, tiled, etc.)
+        images: list["torch.Tensor"] | None = None
+
+    @dataclass
+    class TransformerInput:  # type: ignore
+        """
+        This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
+        are expected to be "embedded" via encoders sitting before this layer in the model.
+        """
+
+        tokens: "torch.Tensor"
+        # tokens_position defines the position of the tokens in each batch,
+        # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
+        # - when it is an int, the start position are the same for all batches
+        tokens_position: "torch.Tensor | int"  # whole union quoted so it is never evaluated against a None torch
+        image_embedding: "MaskedEmbedding | None" = None
+
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: "torch.Tensor"
+
+else:
+    # Fallback stubs when torch unavailable
+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: Any
+        mask: Any
+
+    @dataclass
+    class LLMInput:  # type: ignore
+        tokens: Any
+        images: Any = None
+
+    @dataclass
+    class TransformerInput:  # type: ignore
+        tokens: Any
+        tokens_position: Any
+        image_embedding: Any = None
+
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: Any
 
 
 TransformerOutput = LLMOutput
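Because the fallback stubs type every field as `Any`, code that merely passes token data around keeps working without torch. A hypothetical usage sketch (not part of the commit) of the stub dataclasses:

from llama_stack.models.llama.llama4.datatypes import LLMInput, TransformerInput

# With torch absent, the Any-typed stubs accept plain Python lists.
llm_in = LLMInput(tokens=[1, 2, 3])
tx_in = TransformerInput(tokens=[1, 2, 3], tokens_position=0)
assert llm_in.images is None and tx_in.image_embedding is None

The same constructor calls also work in the torch-enabled branch, so callers do not need to know which variant is active.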
tests/unit/models/test_llama4_import_torch_free.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Test that Llama-4 modules can be imported and used for text-only operations
+even when torch is not available (addresses issue #2584).
+"""
+
+import builtins
+import importlib
+import sys
+
+import pytest
+
+
+def _block_torch(monkeypatch):
+    """Block torch imports to simulate a torch-free environment."""
+    real_import = builtins.__import__
+
+    def fake_import(name, *args, **kwargs):
+        if name == "torch" or name.startswith("torch."):
+            raise ModuleNotFoundError("No module named 'torch'")
+        return real_import(name, *args, **kwargs)
+
+    sys.modules.pop("torch", None)  # forget any cached import
+    monkeypatch.setattr(builtins, "__import__", fake_import)
+
+
+def test_llama4_chat_format_imports_without_torch(monkeypatch):
+    """Test that llama4.chat_format can be imported when torch is unavailable."""
+    _block_torch(monkeypatch)
+
+    # This should NOT raise ImportError anymore
+    chat_format_module = importlib.import_module("llama_stack.models.llama.llama4.chat_format")
+    assert chat_format_module is not None
+
+
+def test_llama4_text_decoding_works_without_torch(monkeypatch):
+    """Test that text-only tool-calling decoding works without torch."""
+    _block_torch(monkeypatch)
+
+    from llama_stack.models.llama.datatypes import StopReason
+    from llama_stack.models.llama.llama4.chat_format import ChatFormat
+    from llama_stack.models.llama.llama4.tokenizer import Tokenizer
+
+    # Text-only operations should work fine
+    formatter = ChatFormat(Tokenizer.get_instance())
+    content = '[get_weather(location="SF")]<|eot|>'
+    msg = formatter.decode_assistant_message_from_content(content, StopReason.end_of_turn)
+
+    # Verify tool-call parsing works
+    assert msg.tool_calls, "Tool call should be detected"
+    tc = msg.tool_calls[0]
+    assert tc.tool_name == "get_weather"
+    assert tc.arguments == {"location": "SF"}
+
+
+def test_llama4_vision_fails_gracefully_without_torch(monkeypatch):
+    """Test that vision features raise a clear error when torch is unavailable."""
+    _block_torch(monkeypatch)
+
+    from llama_stack.models.llama.llama4.args import Size, VisionArgs
+    from llama_stack.models.llama.llama4.chat_format import ChatFormat
+    from llama_stack.models.llama.llama4.tokenizer import Tokenizer
+
+    # Trying to use vision features should raise a helpful error
+    vision_args = VisionArgs(
+        image_size=Size(height=448, width=448),
+        patch_size=Size(height=14, width=14),
+        dim=512,
+        n_layers=6,
+        n_heads=8,
+        mlp_ratio=4.0,
+        output_dim=4096,
+        pixel_shuffle_ratio=2,
+    )
+    with pytest.raises(ImportError, match="vision features require.*torch.*Pillow"):
+        ChatFormat(Tokenizer.get_instance(), vision_args=vision_args)
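The suite simulates a torch-free environment by monkeypatching `builtins.__import__` and evicting any cached `torch` module, so it also passes on machines where torch is installed. Assuming a normal dev checkout with pytest available, it runs directly:

pytest tests/unit/models/test_llama4_import_torch_free.py -v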