Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-27 14:38:49 +00:00)
fix: Resolve Llama4 tool calling 500 errors
This commit addresses issue #2584 by:

- Implementing lazy torch imports in llama4/chat_format.py and datatypes.py to prevent ModuleNotFoundError in torch-free environments.
- Adding comprehensive unit tests to verify that text-only functionality works without torch and that vision features fail gracefully.
- Ensuring the module remains importable and functional for text-based operations, thus resolving the 500 internal server errors.
parent 3d43e143d2
commit f5c1935c18
3 changed files with 197 additions and 44 deletions
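The unit tests mentioned in the commit message can be sketched roughly as follows. This is an illustrative pytest snippet, not the test file from this commit; the module paths are assumptions, and a torch-free environment is simulated by blocking the imports rather than requiring a separate CPU-only install.

# Sketch only: verify that the llama4 chat-format module imports and degrades
# gracefully when torch/Pillow are missing. Module paths are assumptions.
import importlib
import sys

import pytest

DATATYPES_MOD = "llama_stack.models.llama.llama4.datatypes"
CHAT_FORMAT_MOD = "llama_stack.models.llama.llama4.chat_format"


def test_chat_format_importable_without_torch(monkeypatch):
    # Mapping a module name to None in sys.modules makes `import torch` raise
    # ImportError, simulating a CPU-only install even when torch is installed.
    for blocked in ("torch", "PIL"):
        monkeypatch.setitem(sys.modules, blocked, None)
    # Drop any cached copies so the lazy-import logic runs again.
    for name in (CHAT_FORMAT_MOD, DATATYPES_MOD):
        monkeypatch.delitem(sys.modules, name, raising=False)

    chat_format = importlib.import_module(CHAT_FORMAT_MOD)

    # Text-only usage stays available; vision stubs fail fast and loudly.
    assert chat_format._TORCH_AVAILABLE is False
    with pytest.raises(ImportError):
        chat_format.VariableSizeImageTransform(448)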
llama4/chat_format.py
@@ -4,13 +4,46 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+# NOTE: This file used to import torch unconditionally, which breaks on
+# minimal CPU-only installs and causes HTTP 500s when a Llama-4 model is
+# used for *text-only* tool calling (issue #2584). We now lazy-import
+# torch/vision deps so that simply importing the module or using text chat
+# functions does **not** require torch.
+
 import io
 import json
 import uuid
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+# Attempt to import the heavy vision stack; fall back gracefully if unavailable.
+try:
+    import torch  # type: ignore
+    from PIL import Image as PIL_Image  # type: ignore
+
+    from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform  # noqa: E401
+
+    _TORCH_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):  # torch or PIL or derived deps missing
+    torch = None  # type: ignore
+    PIL_Image = None  # type: ignore
+    _TORCH_AVAILABLE = False
+
+    def _raise_torch_required():
+        raise ImportError(
+            "Llama-4 vision features require the `torch` and `Pillow` packages. "
+            "Install them to enable image processing."
+        )
+
+    # Create dummy transform classes that raise if instantiated
+    class ResizeNormalizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
+
+    class VariableSizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
+
-import torch
-from PIL import Image as PIL_Image
-
 # TODO: either fork these or move them to the common package
 from ..datatypes import (
@@ -27,7 +60,6 @@ from ..datatypes import (
 from ..llama3.tool_utils import ToolUtils
 from .args import VisionArgs
 from .datatypes import LLMInput
-from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
 from .tokenizer import Tokenizer
 
 
@@ -41,20 +73,28 @@ def role_str(role: Role) -> str:
     return role_strs[role]
 
 
-@dataclass
-class TransformedImage:
-    image_tiles: torch.Tensor
-    # is the aspect ratio needed anywhere?
-    aspect_ratio: tuple[int, int]
+# Define TransformedImage so that the module is importable without torch.
+if TYPE_CHECKING or _TORCH_AVAILABLE:
+    # Normal case - torch available.
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: "torch.Tensor"  # quotes to avoid mypy when torch is None
+        aspect_ratio: tuple[int, int]
+else:
+
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: Any  # placeholder; actual methods will never be called
+        aspect_ratio: tuple[int, int]
 
 
-def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: "PIL_Image.Image", bg: tuple[int, int, int] = (255, 255, 255)) -> "PIL_Image.Image":
     if image.mode == "RGBA":
         image.load()  # for png.split()
-        new_img = PIL_Image.new("RGB", image.size, bg)
+        new_img = PIL_Image.new("RGB", image.size, bg)  # type: ignore
         new_img.paste(image, mask=image.split()[3])  # 3 is the alpha channel
         return new_img
-    return image.convert("RGB")
+    return image.convert("RGB")  # type: ignore
 
 
 class ChatFormat:
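The hunk above defines TransformedImage twice: a torch-typed version used when TYPE_CHECKING or torch is available, and an Any-typed placeholder otherwise, with the tensor annotations quoted so the class body never evaluates torch.Tensor at runtime. A standalone sketch of that pattern (the class name here is made up, not the repository's):

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

try:
    import torch
    _TORCH_AVAILABLE = True
except ImportError:
    torch = None
    _TORCH_AVAILABLE = False

if TYPE_CHECKING or _TORCH_AVAILABLE:

    @dataclass
    class TileBatch:
        # String annotations are stored unevaluated, so defining this class
        # never touches torch; type checkers still see torch.Tensor.
        tiles: "torch.Tensor"
        aspect_ratio: tuple[int, int]

else:

    @dataclass
    class TileBatch:  # keeps the name importable on torch-free installs
        tiles: Any
        aspect_ratio: tuple[int, int]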
@@ -75,6 +115,8 @@ class ChatFormat:
         self.image_transform = None
         self.dynamic_image_transform = None
         if vision_args:
+            if not _TORCH_AVAILABLE:
+                _raise_torch_required()
             self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
             self.image_transform = ResizeNormalizeImageTransform(
                 vision_args.image_size.width, vision_args.image_size.height
@@ -98,6 +140,9 @@ class ChatFormat:
         self,
         transformed_image: TransformedImage,
     ) -> list[int]:
+        if not _TORCH_AVAILABLE:
+            _raise_torch_required()
+
         assert self.vision_args is not None, "The model is not vision-enabled"
 
         image_tensor = transformed_image.image_tiles
@@ -164,7 +209,7 @@ class ChatFormat:
                     added_bos = True
 
                 bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
-                image = PIL_Image.open(bytes_io)
+                image = PIL_Image.open(bytes_io)  # type: ignore
                 image = convert_image_to_rgb(image)
                 image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
 
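Besides the import changes, the hunks above add explicit runtime guards: the constructor only requires torch when vision_args is passed, and the image-handling method checks _TORCH_AVAILABLE before doing any work. A generic sketch of that fail-fast shape (not the repository's actual API):

from typing import Any

_TORCH_AVAILABLE = False  # pretend we are on a minimal CPU-only install


def _raise_torch_required() -> None:
    raise ImportError("vision features require the `torch` and `Pillow` packages")


class Formatter:
    def __init__(self, vision_args: Any = None) -> None:
        # Text-only construction (vision_args=None) never needs torch;
        # asking for vision without torch fails immediately and clearly.
        if vision_args is not None and not _TORCH_AVAILABLE:
            _raise_torch_required()
        self.vision_args = vision_args

    def encode_text(self, prompt: str) -> list[str]:
        # The text path never reaches a torch call, so text-only tool calling
        # keeps working instead of surfacing an HTTP 500.
        return prompt.split()

    def encode_image(self, image_bytes: bytes) -> Any:
        if not _TORCH_AVAILABLE:
            _raise_torch_required()
        # (real tensor work would follow here)


print(Formatter().encode_text("hello world"))  # ['hello', 'world']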
llama4/datatypes.py
@@ -5,52 +5,79 @@
 # the root directory of this source tree.
 
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
 
-import torch
+# Lazy import torch to avoid breaking text-only usage
+try:
+    import torch  # type: ignore
+
+    _TORCH_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):
+    torch = None  # type: ignore
+    _TORCH_AVAILABLE = False
 
 
-@dataclass
-class MaskedEmbedding:
-    embedding: torch.Tensor
-    mask: torch.Tensor
+if TYPE_CHECKING or _TORCH_AVAILABLE:
+
+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: "torch.Tensor"
+        mask: "torch.Tensor"
 
-
-@dataclass
-class LLMInput:
-    """
-    This is the input to the LLM from the "user" -- the user in this case views the
-    Llama4 model holistically and does not care or know about its inner workings (e.g.,
-    whether it has an encoder or if it is early fusion or not.)
+    @dataclass
+    class LLMInput:  # type: ignore
+        """
+        This is the input to the LLM from the "user" -- the user in this case views the
+        Llama4 model holistically and does not care or know about its inner workings (e.g.,
+        whether it has an encoder or if it is early fusion or not.)
 
-    This is distinct from the "TransformerInput" class which is really the Llama4
-    backbone operating on early fused modalities and producing text output
-    """
+        This is distinct from the "TransformerInput" class which is really the Llama4
+        backbone operating on early fused modalities and producing text output
+        """
 
-    tokens: torch.Tensor
-
-    # images are already pre-processed (resized, tiled, etc.)
-    images: list[torch.Tensor] | None = None
+        tokens: "torch.Tensor"
+        # images are already pre-processed (resized, tiled, etc.)
+        images: list["torch.Tensor"] | None = None
 
+    @dataclass
+    class TransformerInput:  # type: ignore
+        """
+        This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
+        are expected to be "embedded" via encoders sitting before this layer in the model.
+        """
 
-@dataclass
-class TransformerInput:
-    """
-    This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
-    are expected to be "embedded" via encoders sitting before this layer in the model.
-    """
+        tokens: "torch.Tensor"
+        # tokens_position defines the position of the tokens in each batch,
+        # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
+        # - when it is an int, the start position are the same for all batches
+        tokens_position: "torch.Tensor" | int
+        image_embedding: "MaskedEmbedding" | None = None
 
-    tokens: torch.Tensor
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: "torch.Tensor"
 
-    # tokens_position defines the position of the tokens in each batch,
-    # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
-    # - when it is an int, the start position are the same for all batches
-    tokens_position: torch.Tensor | int
-    image_embedding: MaskedEmbedding | None = None
+else:
+    # Fallback stubs when torch unavailable
+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: Any
+        mask: Any
 
+    @dataclass
+    class LLMInput:  # type: ignore
+        tokens: Any
+        images: Any = None
 
-@dataclass
-class LLMOutput:
-    logits: torch.Tensor
+    @dataclass
+    class TransformerInput:  # type: ignore
+        tokens: Any
+        tokens_position: Any
+        image_embedding: Any = None
+
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: Any
 
 
 TransformerOutput = LLMOutput
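The else branch above may look redundant, but it is what keeps datatypes importable on torch-free installs: chat_format.py imports LLMInput unconditionally, so the names must exist even when they can only hold plain Python objects. A small illustration of a text-only caller relying on that (this mirrors the fallback stubs shown above, not the full torch-typed classes):

from dataclasses import dataclass
from typing import Any


@dataclass
class LLMInput:  # shape of the torch-free fallback stub shown above
    tokens: Any
    # images stay None on the text-only path, so no tensors are ever needed
    images: Any = None


def build_text_input(token_ids: list[int]) -> LLMInput:
    # A text-only chat/tool-calling flow can construct and pass this around
    # without torch ever being imported.
    return LLMInput(tokens=token_ids)


print(build_text_input([1, 2, 3]))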