Sumanth Kamenani 2025-07-24 16:41:17 -04:00 committed by GitHub
commit 87873fc1d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 298 additions and 44 deletions

View file

@@ -4,13 +4,46 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# NOTE: This file used to import torch unconditionally, which breaks on
# minimal CPU-only installs and causes HTTP 500s when a Llama-4 model is
# used for *text-only* tool calling (issue #2584). We now lazy-import
# torch/vision deps so that simply importing the module or using text chat
# functions does **not** require torch.
import io
import json
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
# Attempt to import the heavy vision stack; fall back gracefully if unavailable.
try:
    import torch  # type: ignore
    from PIL import Image as PIL_Image  # type: ignore

    from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform  # noqa: E401

    _TORCH_AVAILABLE = True
except (ModuleNotFoundError, ImportError):  # torch or PIL or derived deps missing
    torch = None  # type: ignore
    PIL_Image = None  # type: ignore
    _TORCH_AVAILABLE = False

    def _raise_torch_required():
        raise ImportError(
            "Llama-4 vision features require the `torch` and `Pillow` packages. "
            "Install them to enable image processing."
        )

    # Create dummy transform classes that raise if instantiated
    class ResizeNormalizeImageTransform:  # type: ignore
        def __init__(self, *_, **__):
            _raise_torch_required()

    class VariableSizeImageTransform:  # type: ignore
        def __init__(self, *_, **__):
            _raise_torch_required()
import torch
from PIL import Image as PIL_Image
# TODO: either fork these or move them to the common package
from ..datatypes import (
@@ -27,7 +60,6 @@ from ..datatypes import (
from ..llama3.tool_utils import ToolUtils
from .args import VisionArgs
from .datatypes import LLMInput
from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
from .tokenizer import Tokenizer
@@ -41,20 +73,28 @@ def role_str(role: Role) -> str:
    return role_strs[role]
# Define TransformedImage so that the module is importable without torch.
if TYPE_CHECKING or _TORCH_AVAILABLE:
    # Normal case - torch available.
    @dataclass
    class TransformedImage:  # type: ignore
        image_tiles: "torch.Tensor"  # quotes to avoid mypy when torch is None
        # is the aspect ratio needed anywhere?
        aspect_ratio: tuple[int, int]

else:

    @dataclass
    class TransformedImage:  # type: ignore
        image_tiles: Any  # placeholder; actual methods will never be called
        aspect_ratio: tuple[int, int]
def convert_image_to_rgb(image: "PIL_Image.Image", bg: tuple[int, int, int] = (255, 255, 255)) -> "PIL_Image.Image":
    if image.mode == "RGBA":
        image.load()  # for png.split()
        new_img = PIL_Image.new("RGB", image.size, bg)  # type: ignore
        new_img.paste(image, mask=image.split()[3])  # 3 is the alpha channel
        return new_img
    return image.convert("RGB")  # type: ignore
class ChatFormat:
@@ -75,6 +115,8 @@ class ChatFormat:
        self.image_transform = None
        self.dynamic_image_transform = None

        if vision_args:
            if not _TORCH_AVAILABLE:
                _raise_torch_required()
            self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
            self.image_transform = ResizeNormalizeImageTransform(
                vision_args.image_size.width, vision_args.image_size.height
@@ -98,6 +140,9 @@ class ChatFormat:
        self,
        transformed_image: TransformedImage,
    ) -> list[int]:
        if not _TORCH_AVAILABLE:
            _raise_torch_required()
        assert self.vision_args is not None, "The model is not vision-enabled"

        image_tensor = transformed_image.image_tiles
@@ -164,7 +209,7 @@ class ChatFormat:
                    added_bos = True

                bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
                image = PIL_Image.open(bytes_io)  # type: ignore
                image = convert_image_to_rgb(image)
                image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
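With the lazy imports above, text-only chat formatting works on a torch-free install, while constructing a vision-enabled ChatFormat fails fast with the ImportError raised by _raise_torch_required. A minimal usage sketch (assuming the same module paths exercised by the new tests below):

from llama_stack.models.llama.datatypes import StopReason
from llama_stack.models.llama.llama4.chat_format import ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer

# Text-only path: no vision_args, so neither torch nor Pillow is needed.
formatter = ChatFormat(Tokenizer.get_instance())
msg = formatter.decode_assistant_message_from_content(
    '[get_weather(location="SF")]<|eot|>', StopReason.end_of_turn
)
print(msg.tool_calls[0].tool_name)  # -> get_weather

# Vision path: ChatFormat(Tokenizer.get_instance(), vision_args=...) raises
# ImportError when torch/Pillow are not installed.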

View file

@@ -5,18 +5,27 @@
# the root directory of this source tree.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
# Lazy import torch to avoid breaking text-only usage
try:
    import torch  # type: ignore

    _TORCH_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    torch = None  # type: ignore
    _TORCH_AVAILABLE = False
if TYPE_CHECKING or _TORCH_AVAILABLE:
    @dataclass
    class MaskedEmbedding:  # type: ignore
        embedding: "torch.Tensor"
        mask: "torch.Tensor"

    @dataclass
    class LLMInput:  # type: ignore
        """
        This is the input to the LLM from the "user" -- the user in this case views the
        Llama4 model holistically and does not care or know about its inner workings (e.g.,
@@ -26,31 +35,49 @@ class LLMInput:
        backbone operating on early fused modalities and producing text output
        """

        tokens: "torch.Tensor"

        # images are already pre-processed (resized, tiled, etc.)
        images: list["torch.Tensor"] | None = None

    @dataclass
    class TransformerInput:  # type: ignore
        """
        This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
        are expected to be "embedded" via encoders sitting before this layer in the model.
        """

        tokens: "torch.Tensor"

        # tokens_position defines the position of the tokens in each batch,
        # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
        # - when it is an int, the start position are the same for all batches
        tokens_position: "torch.Tensor" | int
        image_embedding: "MaskedEmbedding" | None = None

    @dataclass
    class LLMOutput:  # type: ignore
        logits: "torch.Tensor"
else:
    # Fallback stubs when torch unavailable
    @dataclass
    class MaskedEmbedding:  # type: ignore
        embedding: Any
        mask: Any

    @dataclass
    class LLMInput:  # type: ignore
        tokens: Any
        images: Any = None

    @dataclass
    class TransformerInput:  # type: ignore
        tokens: Any
        tokens_position: Any
        image_embedding: Any = None

    @dataclass
    class LLMOutput:  # type: ignore
        logits: Any
TransformerOutput = LLMOutput
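Because the fallback stubs type every field as Any, modules that merely construct or pass these containers around keep importing and running without torch; only code that actually manipulates tensors needs the real types. A small sketch (hypothetical values, assuming torch is absent and the same package layout as the tests below):

from llama_stack.models.llama.llama4.datatypes import LLMInput, TransformerInput

# The Any-typed stubs accept plain Python values, so plumbing code still works.
llm_input = LLMInput(tokens=[1, 2, 3])
transformer_input = TransformerInput(tokens=llm_input.tokens, tokens_position=0)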

View file

@@ -19,9 +19,42 @@ LLAMA3_VOCAB_SIZE = 128256
def resolve_model(descriptor: str) -> Model | None:
"""Return the canonical `Model` that matches *descriptor*.
The helper originally accepted the model *descriptor* (e.g. "Llama-4-Scout-17B-16E-Instruct")
or the HuggingFace repository path (e.g. "meta-llama/Llama-4-Scout-17B-16E-Instruct").
Review feedback (see PR #2796) highlighted that callers - especially provider
adaptors - were passing provider-qualified aliases such as
"together/meta-llama/Llama-4-Scout-17B-16E-Instruct"
Having provider-specific logic here is undesirable. Instead of hard-coding
aliases in *this* file we normalise the incoming descriptor by stripping a
leading provider prefix of the form "<provider>/" (e.g. "together/",
"groq/", ) *once* and then retry the lookup. This keeps sku_list
provider-agnostic while still resolving all legitimate aliases that
individual providers register in their own modules.
"""
# Direct match against descriptor or HF repo.
for m in all_registered_models(): for m in all_registered_models():
if descriptor in (m.descriptor(), m.huggingface_repo): if descriptor in (m.descriptor(), m.huggingface_repo):
return m return m
    # Handle provider-prefixed aliases - strip provider prefix ("together/", "groq/", etc.) if present.
    if "/" in descriptor:
        # Many provider aliases look like "<provider>/<repo_path>"; we only need
        # the repo_path (everything after the first slash) for a successful
        # lookup. Splitting just once avoids over-stripping repo paths that
        # legitimately contain more than one component (e.g. "meta-llama/…").
        _, remainder = descriptor.split("/", 1)

        # Recursively attempt to resolve the stripped descriptor to avoid code
        # duplication. The depth here is at most 1 because the second call will
        # hit the fast path above.
        if remainder != descriptor:  # guard against infinite recursion
            return resolve_model(remainder)

    return None
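In effect, provider-qualified aliases now resolve to the same canonical model as the bare descriptor or the HuggingFace repo path. A quick sketch of the behaviour the new tests below lock in:

from llama_stack.models.llama.sku_list import resolve_model

assert resolve_model("meta-llama/Llama-4-Scout-17B-16E-Instruct") is not None
assert resolve_model("together/meta-llama/Llama-4-Scout-17B-16E-Instruct") is not None
assert resolve_model("groq/meta-llama/Llama-4-Scout-17B-16E-Instruct") is not None
assert resolve_model("unknown-provider/nonexistent-model") is None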

View file

@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Test that Llama-4 modules can be imported and used for text-only operations
even when torch is not available (addresses issue #2584).
"""
import builtins
import importlib
import sys
import pytest
def _block_torch(monkeypatch):
    """Block torch imports to simulate torch-free environment."""
    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        if name == "torch" or name.startswith("torch."):
            raise ModuleNotFoundError("No module named 'torch'")
        return real_import(name, *args, **kwargs)

    sys.modules.pop("torch", None)  # forget any cached import
    monkeypatch.setattr(builtins, "__import__", fake_import)


def test_llama4_chat_format_imports_without_torch(monkeypatch):
    """Test that llama4.chat_format can be imported when torch is unavailable."""
    _block_torch(monkeypatch)
    # This should NOT raise ImportError anymore
    chat_format_module = importlib.import_module("llama_stack.models.llama.llama4.chat_format")
    assert chat_format_module is not None


def test_llama4_text_decoding_works_without_torch(monkeypatch):
    """Test that text-only tool calling decoding works without torch."""
    _block_torch(monkeypatch)
    from llama_stack.models.llama.datatypes import StopReason
    from llama_stack.models.llama.llama4.chat_format import ChatFormat
    from llama_stack.models.llama.llama4.tokenizer import Tokenizer

    # Text-only operations should work fine
    formatter = ChatFormat(Tokenizer.get_instance())
    content = '[get_weather(location="SF")]<|eot|>'
    msg = formatter.decode_assistant_message_from_content(content, StopReason.end_of_turn)

    # Verify tool calling parsing works
    assert msg.tool_calls, "Tool call should be detected"
    tc = msg.tool_calls[0]
    assert tc.tool_name == "get_weather"
    assert tc.arguments == {"location": "SF"}


def test_llama4_vision_fails_gracefully_without_torch(monkeypatch):
    """Test that vision features raise clear error when torch unavailable."""
    _block_torch(monkeypatch)
    from llama_stack.models.llama.llama4.args import Size, VisionArgs
    from llama_stack.models.llama.llama4.chat_format import ChatFormat
    from llama_stack.models.llama.llama4.tokenizer import Tokenizer

    # Trying to use vision features should raise helpful error
    vision_args = VisionArgs(
        image_size=Size(height=448, width=448),
        patch_size=Size(height=14, width=14),
        dim=512,
        n_layers=6,
        n_heads=8,
        mlp_ratio=4.0,
        output_dim=4096,
        pixel_shuffle_ratio=2,
    )
    with pytest.raises(ImportError, match="vision features require.*torch.*Pillow"):
        ChatFormat(Tokenizer.get_instance(), vision_args=vision_args)

View file

@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_list import resolve_model
def test_resolve_by_descriptor():
    """Test normal resolution by model descriptor."""
    model = resolve_model("Llama-4-Scout-17B-16E-Instruct")
    assert model is not None
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_resolve_by_huggingface_repo():
    """Test normal resolution by HuggingFace repo path."""
    model = resolve_model("meta-llama/Llama-4-Scout-17B-16E-Instruct")
    assert model is not None
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_together_alias_resolves():
    """Test that Together-prefixed alias resolves via generic prefix stripping."""
    alias = "together/meta-llama/Llama-4-Scout-17B-16E-Instruct"
    model = resolve_model(alias)
    assert model is not None, f"Model should resolve for alias {alias}"
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_groq_alias_resolves():
    """Test that Groq-prefixed alias resolves via generic prefix stripping."""
    alias = "groq/meta-llama/Llama-4-Scout-17B-16E-Instruct"
    model = resolve_model(alias)
    assert model is not None, f"Model should resolve for alias {alias}"
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_unknown_model_returns_none():
    """Test that unknown model descriptors return None."""
    model = resolve_model("nonexistent-model")
    assert model is None


def test_unknown_provider_prefix_returns_none():
    """Test that unknown provider prefix with unknown model returns None."""
    model = resolve_model("unknown-provider/nonexistent-model")
    assert model is None


def test_empty_string_returns_none():
    """Test that empty string returns None."""
    model = resolve_model("")
    assert model is None


def test_slash_only_returns_none():
    """Test that just a slash returns None."""
    model = resolve_model("/")
    assert model is None


def test_multiple_slashes_handled():
    """Test that paths with multiple slashes are handled correctly."""
    # This should strip "provider/" and try "path/to/model"
    model = resolve_model("provider/path/to/model")
    assert model is None  # Should be None since "path/to/model" doesn't exist