Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-27 06:28:50 +00:00)

Merge 561912064c into 632cf9eb72
This commit is contained in commit 87873fc1d2.

5 changed files with 298 additions and 44 deletions; the two new test files are shown below.
tests/unit/models/test_llama4_import_torch_free.py (new file, 81 lines)

@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Test that Llama-4 modules can be imported and used for text-only operations
even when torch is not available (addresses issue #2584).
"""

import builtins
import importlib
import sys

import pytest


def _block_torch(monkeypatch):
    """Block torch imports to simulate a torch-free environment."""
    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        if name == "torch" or name.startswith("torch."):
            raise ModuleNotFoundError("No module named 'torch'")
        return real_import(name, *args, **kwargs)

    sys.modules.pop("torch", None)  # forget any cached import
    monkeypatch.setattr(builtins, "__import__", fake_import)


def test_llama4_chat_format_imports_without_torch(monkeypatch):
    """Test that llama4.chat_format can be imported when torch is unavailable."""
    _block_torch(monkeypatch)

    # This should NOT raise ImportError anymore
    chat_format_module = importlib.import_module("llama_stack.models.llama.llama4.chat_format")
    assert chat_format_module is not None


def test_llama4_text_decoding_works_without_torch(monkeypatch):
    """Test that text-only tool-calling decoding works without torch."""
    _block_torch(monkeypatch)

    from llama_stack.models.llama.datatypes import StopReason
    from llama_stack.models.llama.llama4.chat_format import ChatFormat
    from llama_stack.models.llama.llama4.tokenizer import Tokenizer

    # Text-only operations should work fine
    formatter = ChatFormat(Tokenizer.get_instance())
    content = '[get_weather(location="SF")]<|eot|>'
    msg = formatter.decode_assistant_message_from_content(content, StopReason.end_of_turn)

    # Verify that tool-call parsing works
    assert msg.tool_calls, "Tool call should be detected"
    tc = msg.tool_calls[0]
    assert tc.tool_name == "get_weather"
    assert tc.arguments == {"location": "SF"}


def test_llama4_vision_fails_gracefully_without_torch(monkeypatch):
    """Test that vision features raise a clear error when torch is unavailable."""
    _block_torch(monkeypatch)

    from llama_stack.models.llama.llama4.args import Size, VisionArgs
    from llama_stack.models.llama.llama4.chat_format import ChatFormat
    from llama_stack.models.llama.llama4.tokenizer import Tokenizer

    # Trying to use vision features should raise a helpful error
    vision_args = VisionArgs(
        image_size=Size(height=448, width=448),
        patch_size=Size(height=14, width=14),
        dim=512,
        n_layers=6,
        n_heads=8,
        mlp_ratio=4.0,
        output_dim=4096,
        pixel_shuffle_ratio=2,
    )
    with pytest.raises(ImportError, match="vision features require.*torch.*Pillow"):
        ChatFormat(Tokenizer.get_instance(), vision_args=vision_args)
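Only the test side of the change appears in this view; the production-side edits that make these tests pass are among the other changed files. The likely shape of that fix is a deferred import: torch and Pillow are loaded only when a vision code path is actually exercised, so the module itself imports cleanly without them. A minimal sketch of the pattern, assuming a hypothetical `_require_vision_deps` helper and a simplified `ChatFormat` skeleton (the real llama_stack code may be structured differently):

# Sketch only: `_require_vision_deps` and this ChatFormat skeleton are
# illustrative assumptions, not the actual llama_stack implementation.
def _require_vision_deps():
    """Import heavy vision dependencies only when vision features are requested."""
    try:
        import torch  # noqa: F401
        from PIL import Image  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "vision features require torch and Pillow; install them to enable image input"
        ) from e


class ChatFormat:
    def __init__(self, tokenizer, vision_args=None):
        self.tokenizer = tokenizer
        self.vision_args = vision_args
        if vision_args is not None:
            # Text-only construction never touches torch, so the module
            # stays importable in a torch-free environment.
            _require_vision_deps()

Note how the error message is phrased so that it would satisfy the `pytest.raises(..., match="vision features require.*torch.*Pillow")` assertion in the vision test above.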
tests/unit/models/test_sku_resolve_alias.py (new file, 68 lines)

@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.sku_list import resolve_model


def test_resolve_by_descriptor():
    """Test normal resolution by model descriptor."""
    model = resolve_model("Llama-4-Scout-17B-16E-Instruct")
    assert model is not None
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_resolve_by_huggingface_repo():
    """Test normal resolution by HuggingFace repo path."""
    model = resolve_model("meta-llama/Llama-4-Scout-17B-16E-Instruct")
    assert model is not None
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_together_alias_resolves():
    """Test that a Together-prefixed alias resolves via generic prefix stripping."""
    alias = "together/meta-llama/Llama-4-Scout-17B-16E-Instruct"
    model = resolve_model(alias)
    assert model is not None, f"Model should resolve for alias {alias}"
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_groq_alias_resolves():
    """Test that a Groq-prefixed alias resolves via generic prefix stripping."""
    alias = "groq/meta-llama/Llama-4-Scout-17B-16E-Instruct"
    model = resolve_model(alias)
    assert model is not None, f"Model should resolve for alias {alias}"
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"


def test_unknown_model_returns_none():
    """Test that unknown model descriptors return None."""
    model = resolve_model("nonexistent-model")
    assert model is None


def test_unknown_provider_prefix_returns_none():
    """Test that an unknown provider prefix with an unknown model returns None."""
    model = resolve_model("unknown-provider/nonexistent-model")
    assert model is None


def test_empty_string_returns_none():
    """Test that an empty string returns None."""
    model = resolve_model("")
    assert model is None


def test_slash_only_returns_none():
    """Test that a bare slash returns None."""
    model = resolve_model("/")
    assert model is None


def test_multiple_slashes_handled():
    """Test that identifiers with multiple slashes are handled correctly."""
    # This should strip "provider/" and try "path/to/model"
    model = resolve_model("provider/path/to/model")
    assert model is None  # "path/to/model" doesn't exist
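The alias tests encode a generic fallback: if the full identifier is unknown, strip one leading `provider/` segment and retry. A minimal sketch of that lookup, assuming a flat registry keyed by descriptor and HuggingFace repo (the real `resolve_model` in `sku_list` may differ):

# Sketch only: a simplified stand-in for llama_stack's sku_list logic,
# with a hypothetical dict registry instead of the real model list.
_REGISTRY = {
    "Llama-4-Scout-17B-16E-Instruct": "<model object>",
    "meta-llama/Llama-4-Scout-17B-16E-Instruct": "<model object>",
}


def resolve_model(descriptor: str):
    """Resolve a model id, tolerating one 'provider/' prefix such as 'groq/...'."""
    if descriptor in _REGISTRY:
        return _REGISTRY[descriptor]
    # Generic fallback: drop the first path segment and retry once, so
    # "together/meta-llama/X" resolves the same as "meta-llama/X".
    if "/" in descriptor:
        _, _, rest = descriptor.partition("/")
        if rest and rest in _REGISTRY:
            return _REGISTRY[rest]
    return None

Stripping only the first segment is what makes `together/meta-llama/...` resolve while `provider/path/to/model`, `/`, and the empty string all correctly fall through to None, matching the test expectations above.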