Merge 561912064c into 632cf9eb72
commit 87873fc1d2
5 changed files with 298 additions and 44 deletions

llama_stack/models/llama/llama4/chat_format.py
@@ -4,13 +4,46 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+# NOTE: This file used to import torch unconditionally, which breaks on
+# minimal CPU-only installs and causes HTTP 500s when a Llama-4 model is
+# used for *text-only* tool calling (issue #2584). We now lazy-import
+# torch/vision deps so that simply importing the module or using text chat
+# functions does **not** require torch.
+
 import io
 import json
 import uuid
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any

+# Attempt to import the heavy vision stack; fall back gracefully if unavailable.
+try:
+    import torch  # type: ignore
+    from PIL import Image as PIL_Image  # type: ignore
+
+    from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform  # noqa: E401
+
+    _TORCH_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):  # torch or PIL or derived deps missing
+    torch = None  # type: ignore
+    PIL_Image = None  # type: ignore
+    _TORCH_AVAILABLE = False
+
+    def _raise_torch_required():
+        raise ImportError(
+            "Llama-4 vision features require the `torch` and `Pillow` packages. "
+            "Install them to enable image processing."
+        )
+
+    # Create dummy transform classes that raise if instantiated
+    class ResizeNormalizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
+
+    class VariableSizeImageTransform:  # type: ignore
+        def __init__(self, *_, **__):
+            _raise_torch_required()
+
-import torch
-from PIL import Image as PIL_Image

 # TODO: either fork these or move them to the common package
 from ..datatypes import (
@@ -27,7 +60,6 @@ from ..datatypes import (
 from ..llama3.tool_utils import ToolUtils
 from .args import VisionArgs
 from .datatypes import LLMInput
-from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
 from .tokenizer import Tokenizer

@@ -41,20 +73,28 @@ def role_str(role: Role) -> str:
     return role_strs[role]


-@dataclass
-class TransformedImage:
-    image_tiles: torch.Tensor
-    # is the aspect ratio needed anywhere?
-    aspect_ratio: tuple[int, int]
+# Define TransformedImage so that the module is importable without torch.
+if TYPE_CHECKING or _TORCH_AVAILABLE:
+    # Normal case - torch available.
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: "torch.Tensor"  # quotes to avoid mypy when torch is None
+        aspect_ratio: tuple[int, int]
+else:
+
+    @dataclass
+    class TransformedImage:  # type: ignore
+        image_tiles: Any  # placeholder; actual methods will never be called
+        aspect_ratio: tuple[int, int]


-def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: "PIL_Image.Image", bg: tuple[int, int, int] = (255, 255, 255)) -> "PIL_Image.Image":
     if image.mode == "RGBA":
         image.load()  # for png.split()
-        new_img = PIL_Image.new("RGB", image.size, bg)
+        new_img = PIL_Image.new("RGB", image.size, bg)  # type: ignore
         new_img.paste(image, mask=image.split()[3])  # 3 is the alpha channel
         return new_img
-    return image.convert("RGB")
+    return image.convert("RGB")  # type: ignore


 class ChatFormat:
@@ -75,6 +115,8 @@ class ChatFormat:
         self.image_transform = None
         self.dynamic_image_transform = None
         if vision_args:
+            if not _TORCH_AVAILABLE:
+                _raise_torch_required()
             self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
             self.image_transform = ResizeNormalizeImageTransform(
                 vision_args.image_size.width, vision_args.image_size.height
@@ -98,6 +140,9 @@ class ChatFormat:
         self,
         transformed_image: TransformedImage,
     ) -> list[int]:
+        if not _TORCH_AVAILABLE:
+            _raise_torch_required()
+
         assert self.vision_args is not None, "The model is not vision-enabled"

         image_tensor = transformed_image.image_tiles
@@ -164,7 +209,7 @@ class ChatFormat:
                     added_bos = True

                 bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
-                image = PIL_Image.open(bytes_io)
+                image = PIL_Image.open(bytes_io)  # type: ignore
                 image = convert_image_to_rgb(image)
                 image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
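
For context, the guard above boils down to a small, reusable pattern. The following is a minimal, self-contained sketch of that pattern, with Pillow standing in for the heavy vision stack; the module and function names are illustrative and are not part of this change:

# Illustrative sketch, not part of the PR: the same lazy-import guard as in
# chat_format.py above, with Pillow standing in for the heavy dependency.
try:
    from PIL import Image  # optional heavy dependency

    _PIL_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    Image = None  # type: ignore
    _PIL_AVAILABLE = False


def _raise_pillow_required() -> None:
    raise ImportError("This feature requires the `Pillow` package. Install it to enable image processing.")


def image_size(path: str) -> tuple[int, int]:
    # Image code path: fail with a clear ImportError instead of a confusing
    # NameError (or an HTTP 500 further up the stack) when Pillow is missing.
    if not _PIL_AVAILABLE:
        _raise_pillow_required()
    with Image.open(path) as img:
        return img.size


def word_count(text: str) -> int:
    # Text-only code path: importing this module and calling this function
    # never touches the optional dependency.
    return len(text.split())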

llama_stack/models/llama/llama4/datatypes.py
@@ -5,52 +5,79 @@
 # the root directory of this source tree.

 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any

-import torch
+# Lazy import torch to avoid breaking text-only usage
+try:
+    import torch  # type: ignore
+
+    _TORCH_AVAILABLE = True
+except (ModuleNotFoundError, ImportError):
+    torch = None  # type: ignore
+    _TORCH_AVAILABLE = False


-@dataclass
-class MaskedEmbedding:
-    embedding: torch.Tensor
-    mask: torch.Tensor
+if TYPE_CHECKING or _TORCH_AVAILABLE:

+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: "torch.Tensor"
+        mask: "torch.Tensor"

-@dataclass
-class LLMInput:
-    """
-    This is the input to the LLM from the "user" -- the user in this case views the
-    Llama4 model holistically and does not care or know about its inner workings (e.g.,
-    whether it has an encoder or if it is early fusion or not.)
+    @dataclass
+    class LLMInput:  # type: ignore
+        """
+        This is the input to the LLM from the "user" -- the user in this case views the
+        Llama4 model holistically and does not care or know about its inner workings (e.g.,
+        whether it has an encoder or if it is early fusion or not.)

-    This is distinct from the "TransformerInput" class which is really the Llama4
-    backbone operating on early fused modalities and producing text output
-    """
+        This is distinct from the "TransformerInput" class which is really the Llama4
+        backbone operating on early fused modalities and producing text output
+        """

-    tokens: torch.Tensor
+        tokens: "torch.Tensor"
+        # images are already pre-processed (resized, tiled, etc.)
+        images: list["torch.Tensor"] | None = None

-    # images are already pre-processed (resized, tiled, etc.)
-    images: list[torch.Tensor] | None = None
+    @dataclass
+    class TransformerInput:  # type: ignore
+        """
+        This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
+        are expected to be "embedded" via encoders sitting before this layer in the model.
+        """

+        tokens: "torch.Tensor"
+        # tokens_position defines the position of the tokens in each batch,
+        # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
+        # - when it is an int, the start position are the same for all batches
+        tokens_position: "torch.Tensor" | int
+        image_embedding: "MaskedEmbedding" | None = None

-@dataclass
-class TransformerInput:
-    """
-    This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
-    are expected to be "embedded" via encoders sitting before this layer in the model.
-    """
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: "torch.Tensor"

-    tokens: torch.Tensor
+else:
+    # Fallback stubs when torch unavailable
+    @dataclass
+    class MaskedEmbedding:  # type: ignore
+        embedding: Any
+        mask: Any

-    # tokens_position defines the position of the tokens in each batch,
-    # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
-    # - when it is an int, the start position are the same for all batches
-    tokens_position: torch.Tensor | int
-    image_embedding: MaskedEmbedding | None = None
+    @dataclass
+    class LLMInput:  # type: ignore
+        tokens: Any
+        images: Any = None

+    @dataclass
+    class TransformerInput:  # type: ignore
+        tokens: Any
+        tokens_position: Any
+        image_embedding: Any = None

-@dataclass
-class LLMOutput:
-    logits: torch.Tensor
+    @dataclass
+    class LLMOutput:  # type: ignore
+        logits: Any


 TransformerOutput = LLMOutput
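
The quoted annotations above (for example "torch.Tensor") are what keep these dataclasses importable when torch is None: a field annotation written as a string literal is stored verbatim in __annotations__ and is not evaluated when the class is defined, so only a static type checker resolves it. A tiny sketch of that behaviour, using a made-up Embedding class rather than anything from this change:

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

try:
    import torch  # type: ignore
except (ModuleNotFoundError, ImportError):
    torch = None  # type: ignore

if TYPE_CHECKING or torch is not None:

    @dataclass
    class Embedding:
        # Stored as the plain string "torch.Tensor"; dataclass records it in
        # __annotations__ without evaluating it, so nothing here needs torch
        # at class-definition time.
        weights: "torch.Tensor"

else:

    @dataclass
    class Embedding:  # type: ignore
        weights: Any  # degraded stub for torch-free installs


print(Embedding.__annotations__)  # e.g. {'weights': 'torch.Tensor'} when torch is present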

llama_stack/models/llama/sku_list.py
@@ -19,9 +19,42 @@ LLAMA3_VOCAB_SIZE = 128256


 def resolve_model(descriptor: str) -> Model | None:
+    """Return the canonical `Model` that matches *descriptor*.
+
+    The helper originally accepted the model *descriptor* (e.g. "Llama-4-Scout-17B-16E-Instruct")
+    or the HuggingFace repository path (e.g. "meta-llama/Llama-4-Scout-17B-16E-Instruct").
+
+    Review feedback (see PR #2796) highlighted that callers - especially provider
+    adaptors - were passing provider-qualified aliases such as
+
+        "together/meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+    Having provider-specific logic here is undesirable. Instead of hard-coding
+    aliases in *this* file we normalise the incoming descriptor by stripping a
+    leading provider prefix of the form "<provider>/" (e.g. "together/",
+    "groq/", …) *once* and then retry the lookup. This keeps sku_list
+    provider-agnostic while still resolving all legitimate aliases that
+    individual providers register in their own modules.
+    """
+
+    # Direct match against descriptor or HF repo.
     for m in all_registered_models():
         if descriptor in (m.descriptor(), m.huggingface_repo):
             return m
+
+    # Handle provider-prefixed aliases - strip provider prefix ("together/", "groq/", etc.) if present.
+    if "/" in descriptor:
+        # Many provider aliases look like "<provider>/<repo_path>"; we only need
+        # the repo_path (everything after the first slash) for a successful
+        # lookup. Splitting just once avoids over-stripping repo paths that
+        # legitimately contain more than one component (e.g. "meta-llama/…").
+        _, remainder = descriptor.split("/", 1)
+        # Recursively attempt to resolve the stripped descriptor to avoid code
+        # duplication. The depth here is at most 1 because the second call will
+        # hit the fast path above.
+        if remainder != descriptor:  # guard against infinite recursion
+            return resolve_model(remainder)
+
     return None
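
A short usage sketch of the behaviour the docstring describes (illustrative only; it assumes the Llama-4 Scout SKU is registered, as the new unit tests below also assume):

from llama_stack.models.llama.sku_list import resolve_model

# The plain HF repo path resolves via the direct-match fast path.
direct = resolve_model("meta-llama/Llama-4-Scout-17B-16E-Instruct")

# A provider-qualified alias has one "<provider>/" prefix stripped and is retried,
# so it resolves to the same SKU.
aliased = resolve_model("together/meta-llama/Llama-4-Scout-17B-16E-Instruct")

assert direct is not None and aliased is not None
assert direct.core_model_id == aliased.core_model_id

# Unknown descriptors, with or without a prefix, still return None rather than raising.
assert resolve_model("unknown-provider/nonexistent-model") is None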

tests/unit/models/test_llama4_import_torch_free.py (new file, 81 lines)
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Test that Llama-4 modules can be imported and used for text-only operations
even when torch is not available (addresses issue #2584).
"""

import builtins
import importlib
import sys

import pytest


def _block_torch(monkeypatch):
    """Block torch imports to simulate torch-free environment."""
    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        if name == "torch" or name.startswith("torch."):
            raise ModuleNotFoundError("No module named 'torch'")
        return real_import(name, *args, **kwargs)

    sys.modules.pop("torch", None)  # forget any cached import
    monkeypatch.setattr(builtins, "__import__", fake_import)


def test_llama4_chat_format_imports_without_torch(monkeypatch):
    """Test that llama4.chat_format can be imported when torch is unavailable."""
    _block_torch(monkeypatch)

    # This should NOT raise ImportError anymore
    chat_format_module = importlib.import_module("llama_stack.models.llama.llama4.chat_format")
    assert chat_format_module is not None


def test_llama4_text_decoding_works_without_torch(monkeypatch):
    """Test that text-only tool calling decoding works without torch."""
    _block_torch(monkeypatch)

    from llama_stack.models.llama.datatypes import StopReason
    from llama_stack.models.llama.llama4.chat_format import ChatFormat
    from llama_stack.models.llama.llama4.tokenizer import Tokenizer

    # Text-only operations should work fine
    formatter = ChatFormat(Tokenizer.get_instance())
    content = '[get_weather(location="SF")]<|eot|>'
    msg = formatter.decode_assistant_message_from_content(content, StopReason.end_of_turn)

    # Verify tool calling parsing works
    assert msg.tool_calls, "Tool call should be detected"
    tc = msg.tool_calls[0]
    assert tc.tool_name == "get_weather"
    assert tc.arguments == {"location": "SF"}


def test_llama4_vision_fails_gracefully_without_torch(monkeypatch):
    """Test that vision features raise clear error when torch unavailable."""
    _block_torch(monkeypatch)

    from llama_stack.models.llama.llama4.args import Size, VisionArgs
    from llama_stack.models.llama.llama4.chat_format import ChatFormat
    from llama_stack.models.llama.llama4.tokenizer import Tokenizer

    # Trying to use vision features should raise helpful error
    vision_args = VisionArgs(
        image_size=Size(height=448, width=448),
        patch_size=Size(height=14, width=14),
        dim=512,
        n_layers=6,
        n_heads=8,
        mlp_ratio=4.0,
        output_dim=4096,
        pixel_shuffle_ratio=2,
    )
    with pytest.raises(ImportError, match="vision features require.*torch.*Pillow"):
        ChatFormat(Tokenizer.get_instance(), vision_args=vision_args)

tests/unit/models/test_sku_resolve_alias.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.sku_list import resolve_model


def test_resolve_by_descriptor():
    """Test normal resolution by model descriptor."""
    model = resolve_model("Llama-4-Scout-17B-16E-Instruct")
    assert model is not None
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_resolve_by_huggingface_repo():
    """Test normal resolution by HuggingFace repo path."""
    model = resolve_model("meta-llama/Llama-4-Scout-17B-16E-Instruct")
    assert model is not None
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_together_alias_resolves():
    """Test that Together-prefixed alias resolves via generic prefix stripping."""
    alias = "together/meta-llama/Llama-4-Scout-17B-16E-Instruct"
    model = resolve_model(alias)
    assert model is not None, f"Model should resolve for alias {alias}"
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_groq_alias_resolves():
    """Test that Groq-prefixed alias resolves via generic prefix stripping."""
    alias = "groq/meta-llama/Llama-4-Scout-17B-16E-Instruct"
    model = resolve_model(alias)
    assert model is not None, f"Model should resolve for alias {alias}"
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_unknown_model_returns_none():
    """Test that unknown model descriptors return None."""
    model = resolve_model("nonexistent-model")
    assert model is None


def test_unknown_provider_prefix_returns_none():
    """Test that unknown provider prefix with unknown model returns None."""
    model = resolve_model("unknown-provider/nonexistent-model")
    assert model is None


def test_empty_string_returns_none():
    """Test that empty string returns None."""
    model = resolve_model("")
    assert model is None


def test_slash_only_returns_none():
    """Test that just a slash returns None."""
    model = resolve_model("/")
    assert model is None


def test_multiple_slashes_handled():
    """Test that paths with multiple slashes are handled correctly."""
    # This should strip "provider/" and try "path/to/model"
    model = resolve_model("provider/path/to/model")
    assert model is None  # Should be None since "path/to/model" doesn't exist