Sumanth Kamenani 2025-07-24 16:41:17 -04:00 committed by GitHub
commit 87873fc1d2
5 changed files with 298 additions and 44 deletions

View file

@@ -4,13 +4,46 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

+# NOTE: This file used to import torch unconditionally, which breaks on
+# minimal CPU-only installs and causes HTTP 500s when a Llama-4 model is
+# used for *text-only* tool calling (issue #2584). We now lazy-import
+# torch/vision deps so that simply importing the module or using text chat
+# functions does **not** require torch.
import io
import json
import uuid
from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any

+# Attempt to import the heavy vision stack; fall back gracefully if unavailable.
+try:
+    import torch  # type: ignore
+    from PIL import Image as PIL_Image  # type: ignore
+
+    from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform  # noqa: E401
+
+    _TORCH_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):  # torch or PIL or derived deps missing
+    torch = None  # type: ignore
+    PIL_Image = None  # type: ignore
+    _TORCH_AVAILABLE = False
+
+    def _raise_torch_required():
+        raise ImportError(
+            "Llama-4 vision features require the `torch` and `Pillow` packages. "
+            "Install them to enable image processing."
+        )
+
+    # Create dummy transform classes that raise if instantiated
+    class ResizeNormalizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
+
+    class VariableSizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
-import torch
-from PIL import Image as PIL_Image

# TODO: either fork these or move them to the common package
from ..datatypes import (
@@ -27,7 +60,6 @@ from ..datatypes import (
from ..llama3.tool_utils import ToolUtils
from .args import VisionArgs
from .datatypes import LLMInput
-from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
from .tokenizer import Tokenizer
@@ -41,20 +73,28 @@ def role_str(role: Role) -> str:
    return role_strs[role]


-@dataclass
-class TransformedImage:
-    image_tiles: torch.Tensor
-    # is the aspect ratio needed anywhere?
-    aspect_ratio: tuple[int, int]
+# Define TransformedImage so that the module is importable without torch.
+if TYPE_CHECKING or _TORCH_AVAILABLE:
+    # Normal case - torch available.
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: "torch.Tensor"  # quotes to avoid mypy when torch is None
+        aspect_ratio: tuple[int, int]
+else:
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: Any  # placeholder; actual methods will never be called
+        aspect_ratio: tuple[int, int]


-def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: "PIL_Image.Image", bg: tuple[int, int, int] = (255, 255, 255)) -> "PIL_Image.Image":
    if image.mode == "RGBA":
        image.load()  # for png.split()
-        new_img = PIL_Image.new("RGB", image.size, bg)
+        new_img = PIL_Image.new("RGB", image.size, bg)  # type: ignore
        new_img.paste(image, mask=image.split()[3])  # 3 is the alpha channel
        return new_img
-    return image.convert("RGB")
+    return image.convert("RGB")  # type: ignore


class ChatFormat:
@@ -75,6 +115,8 @@ class ChatFormat:
        self.image_transform = None
        self.dynamic_image_transform = None
        if vision_args:
+            if not _TORCH_AVAILABLE:
+                _raise_torch_required()
            self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
            self.image_transform = ResizeNormalizeImageTransform(
                vision_args.image_size.width, vision_args.image_size.height
@@ -98,6 +140,9 @@ class ChatFormat:
        self,
        transformed_image: TransformedImage,
    ) -> list[int]:
+        if not _TORCH_AVAILABLE:
+            _raise_torch_required()
+
        assert self.vision_args is not None, "The model is not vision-enabled"

        image_tensor = transformed_image.image_tiles
@@ -164,7 +209,7 @@ class ChatFormat:
                    added_bos = True

                bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
-                image = PIL_Image.open(bytes_io)
+                image = PIL_Image.open(bytes_io)  # type: ignore
                image = convert_image_to_rgb(image)
                image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
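
The pattern above generalizes: guard the heavy import with try/except, record availability in a module-level flag, and raise a clear ImportError only when the optional feature is actually exercised. Below is a minimal, self-contained sketch of that idea, not part of the patch; the names `text_only_feature` and `vision_feature` are hypothetical stand-ins.

# --- illustrative sketch (not part of the patch) ---
try:
    import torch  # heavy optional dependency
    _TORCH_AVAILABLE = True
except ImportError:
    torch = None
    _TORCH_AVAILABLE = False

def text_only_feature(tokens: list[int]) -> int:
    # Works regardless of torch availability; never touches the optional stack.
    return len(tokens)

def vision_feature(pixels):
    # Fails loudly, but only when the optional code path is actually used.
    if not _TORCH_AVAILABLE:
        raise ImportError("vision_feature requires `torch`; install it to enable image processing.")
    return torch.as_tensor(pixels).shape
# --- end sketch ---

Importing such a module or calling the text-only function never requires torch, which is exactly what keeps text-only tool calling working on CPU-only installs.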

View file

@@ -5,52 +5,79 @@
# the root directory of this source tree.

from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any

-import torch
+# Lazy import torch to avoid breaking text-only usage
+try:
+    import torch  # type: ignore
+
+    _TORCH_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):
+    torch = None  # type: ignore
+    _TORCH_AVAILABLE = False


-@dataclass
-class MaskedEmbedding:
-    embedding: torch.Tensor
-    mask: torch.Tensor
+if TYPE_CHECKING or _TORCH_AVAILABLE:
+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: "torch.Tensor"
+        mask: "torch.Tensor"
+
+    @dataclass
+    class LLMInput:  # type: ignore
+        """
+        This is the input to the LLM from the "user" -- the user in this case views the
+        Llama4 model holistically and does not care or know about its inner workings (e.g.,
+        whether it has an encoder or if it is early fusion or not.)

-@dataclass
-class LLMInput:
-    """
-    This is the input to the LLM from the "user" -- the user in this case views the
-    Llama4 model holistically and does not care or know about its inner workings (e.g.,
-    whether it has an encoder or if it is early fusion or not.)
+        This is distinct from the "TransformerInput" class which is really the Llama4
+        backbone operating on early fused modalities and producing text output
+        """

-    This is distinct from the "TransformerInput" class which is really the Llama4
-    backbone operating on early fused modalities and producing text output
-    """
+        tokens: "torch.Tensor"

-    tokens: torch.Tensor
+        # images are already pre-processed (resized, tiled, etc.)
+        images: list["torch.Tensor"] | None = None

-    # images are already pre-processed (resized, tiled, etc.)
-    images: list[torch.Tensor] | None = None
+    @dataclass
+    class TransformerInput:  # type: ignore
+        """
+        This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
+        are expected to be "embedded" via encoders sitting before this layer in the model.
+        """
+
+        tokens: "torch.Tensor"
+        # tokens_position defines the position of the tokens in each batch,
+        # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
+        # - when it is an int, the start position are the same for all batches
+        tokens_position: "torch.Tensor" | int
+        image_embedding: "MaskedEmbedding" | None = None

-@dataclass
-class TransformerInput:
-    """
-    This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
-    are expected to be "embedded" via encoders sitting before this layer in the model.
-    """
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: "torch.Tensor"

-    tokens: torch.Tensor
+else:
+    # Fallback stubs when torch unavailable
+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: Any
+        mask: Any

-    # tokens_position defines the position of the tokens in each batch,
-    # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
-    # - when it is an int, the start position are the same for all batches
-    tokens_position: torch.Tensor | int
-    image_embedding: MaskedEmbedding | None = None
+    @dataclass
+    class LLMInput:  # type: ignore
+        tokens: Any
+        images: Any = None
+
+    @dataclass
+    class TransformerInput:  # type: ignore
+        tokens: Any
+        tokens_position: Any
+        image_embedding: Any = None

-@dataclass
-class LLMOutput:
-    logits: torch.Tensor
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: Any


TransformerOutput = LLMOutput
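
The `if TYPE_CHECKING or _TORCH_AVAILABLE:` split is the key trick here: static type checkers always take the first branch (so the quoted `"torch.Tensor"` annotations stay precise), while a torch-free runtime falls through to `Any`-typed stubs that keep the module importable. A reduced sketch of the same pattern follows, with a hypothetical `Batch` class standing in for the real dataclasses.

# --- illustrative sketch (not part of the patch); `Batch` is a hypothetical class ---
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

try:
    import torch
    _TORCH_AVAILABLE = True
except ImportError:
    torch = None
    _TORCH_AVAILABLE = False

if TYPE_CHECKING or _TORCH_AVAILABLE:
    @dataclass
    class Batch:
        tokens: "torch.Tensor"  # string annotation: never evaluated when torch is absent
        images: list["torch.Tensor"] | None = None
else:
    @dataclass
    class Batch:  # stub keeps the module importable without torch
        tokens: Any
        images: Any = None
# --- end sketch ---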

View file

@@ -19,9 +19,42 @@ LLAMA3_VOCAB_SIZE = 128256
def resolve_model(descriptor: str) -> Model | None:
    """Return the canonical `Model` that matches *descriptor*.

    The helper originally accepted the model *descriptor* (e.g. "Llama-4-Scout-17B-16E-Instruct")
    or the HuggingFace repository path (e.g. "meta-llama/Llama-4-Scout-17B-16E-Instruct").

    Review feedback (see PR #2796) highlighted that callers - especially provider
    adaptors - were passing provider-qualified aliases such as
    "together/meta-llama/Llama-4-Scout-17B-16E-Instruct".

    Having provider-specific logic here is undesirable. Instead of hard-coding
    aliases in *this* file we normalise the incoming descriptor by stripping a
    leading provider prefix of the form "<provider>/" (e.g. "together/",
    "groq/", etc.) *once* and then retry the lookup. This keeps sku_list
    provider-agnostic while still resolving all legitimate aliases that
    individual providers register in their own modules.
    """
    # Direct match against descriptor or HF repo.
    for m in all_registered_models():
        if descriptor in (m.descriptor(), m.huggingface_repo):
            return m

    # Handle provider-prefixed aliases - strip provider prefix ("together/", "groq/", etc.) if present.
    if "/" in descriptor:
        # Many provider aliases look like "<provider>/<repo_path>"; we only need
        # the repo_path (everything after the first slash) for a successful
        # lookup. Splitting just once avoids over-stripping repo paths that
        # legitimately contain more than one component (e.g. "meta-llama/…").
        _, remainder = descriptor.split("/", 1)
        # Recursively attempt to resolve the stripped descriptor to avoid code
        # duplication. The depth here is at most 1 because the second call will
        # hit the fast path above.
        if remainder != descriptor:  # guard against infinite recursion
            return resolve_model(remainder)

    return None
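
In practice, all of the following spellings now resolve to the same registered model, while unknown descriptors still return None. This short usage sketch mirrors the new unit tests further down in the commit:

# --- illustrative sketch mirroring the new tests (not part of the patch) ---
from llama_stack.models.llama.sku_list import resolve_model

# Descriptor, HuggingFace repo path, and provider-qualified alias all resolve identically.
for descriptor in (
    "Llama-4-Scout-17B-16E-Instruct",                       # model descriptor
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",            # HuggingFace repo path
    "together/meta-llama/Llama-4-Scout-17B-16E-Instruct",   # provider-qualified alias
):
    model = resolve_model(descriptor)
    assert model is not None
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"

# Unknown descriptors (with or without a provider prefix) fall through to None.
assert resolve_model("unknown-provider/nonexistent-model") is None
# --- end sketch ---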

View file

@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Test that Llama-4 modules can be imported and used for text-only operations
even when torch is not available (addresses issue #2584).
"""

import builtins
import importlib
import sys

import pytest


def _block_torch(monkeypatch):
    """Block torch imports to simulate torch-free environment."""
    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        if name == "torch" or name.startswith("torch."):
            raise ModuleNotFoundError("No module named 'torch'")
        return real_import(name, *args, **kwargs)

    sys.modules.pop("torch", None)  # forget any cached import
    monkeypatch.setattr(builtins, "__import__", fake_import)


def test_llama4_chat_format_imports_without_torch(monkeypatch):
    """Test that llama4.chat_format can be imported when torch is unavailable."""
    _block_torch(monkeypatch)

    # This should NOT raise ImportError anymore
    chat_format_module = importlib.import_module("llama_stack.models.llama.llama4.chat_format")
    assert chat_format_module is not None


def test_llama4_text_decoding_works_without_torch(monkeypatch):
    """Test that text-only tool calling decoding works without torch."""
    _block_torch(monkeypatch)

    from llama_stack.models.llama.datatypes import StopReason
    from llama_stack.models.llama.llama4.chat_format import ChatFormat
    from llama_stack.models.llama.llama4.tokenizer import Tokenizer

    # Text-only operations should work fine
    formatter = ChatFormat(Tokenizer.get_instance())
    content = '[get_weather(location="SF")]<|eot|>'
    msg = formatter.decode_assistant_message_from_content(content, StopReason.end_of_turn)

    # Verify tool calling parsing works
    assert msg.tool_calls, "Tool call should be detected"
    tc = msg.tool_calls[0]
    assert tc.tool_name == "get_weather"
    assert tc.arguments == {"location": "SF"}


def test_llama4_vision_fails_gracefully_without_torch(monkeypatch):
    """Test that vision features raise clear error when torch unavailable."""
    _block_torch(monkeypatch)

    from llama_stack.models.llama.llama4.args import Size, VisionArgs
    from llama_stack.models.llama.llama4.chat_format import ChatFormat
    from llama_stack.models.llama.llama4.tokenizer import Tokenizer

    # Trying to use vision features should raise helpful error
    vision_args = VisionArgs(
        image_size=Size(height=448, width=448),
        patch_size=Size(height=14, width=14),
        dim=512,
        n_layers=6,
        n_heads=8,
        mlp_ratio=4.0,
        output_dim=4096,
        pixel_shuffle_ratio=2,
    )
    with pytest.raises(ImportError, match="vision features require.*torch.*Pillow"):
        ChatFormat(Tokenizer.get_instance(), vision_args=vision_args)

View file

@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.sku_list import resolve_model


def test_resolve_by_descriptor():
    """Test normal resolution by model descriptor."""
    model = resolve_model("Llama-4-Scout-17B-16E-Instruct")
    assert model is not None
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_resolve_by_huggingface_repo():
    """Test normal resolution by HuggingFace repo path."""
    model = resolve_model("meta-llama/Llama-4-Scout-17B-16E-Instruct")
    assert model is not None
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_together_alias_resolves():
    """Test that Together-prefixed alias resolves via generic prefix stripping."""
    alias = "together/meta-llama/Llama-4-Scout-17B-16E-Instruct"
    model = resolve_model(alias)
    assert model is not None, f"Model should resolve for alias {alias}"
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_groq_alias_resolves():
    """Test that Groq-prefixed alias resolves via generic prefix stripping."""
    alias = "groq/meta-llama/Llama-4-Scout-17B-16E-Instruct"
    model = resolve_model(alias)
    assert model is not None, f"Model should resolve for alias {alias}"
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_unknown_model_returns_none():
    """Test that unknown model descriptors return None."""
    model = resolve_model("nonexistent-model")
    assert model is None


def test_unknown_provider_prefix_returns_none():
    """Test that unknown provider prefix with unknown model returns None."""
    model = resolve_model("unknown-provider/nonexistent-model")
    assert model is None


def test_empty_string_returns_none():
    """Test that empty string returns None."""
    model = resolve_model("")
    assert model is None


def test_slash_only_returns_none():
    """Test that just a slash returns None."""
    model = resolve_model("/")
    assert model is None


def test_multiple_slashes_handled():
    """Test that paths with multiple slashes are handled correctly."""
    # This should strip "provider/" and try "path/to/model"
    model = resolve_model("provider/path/to/model")
    assert model is None  # Should be None since "path/to/model" doesn't exist