From 3d43e143d213786459a626275323cd5e7da12029 Mon Sep 17 00:00:00 2001
From: skamenan7
Date: Mon, 21 Jul 2025 14:12:24 -0400
Subject: [PATCH] refactor: make sku_list resolve provider aliases generically

- Replace Together-specific import logic with generic provider prefix stripping
- Address PR #2796 review feedback about provider-specific code in core modules
- Add comprehensive test suite covering normal cases, provider prefixes, and edge cases
- Maintain backward compatibility while keeping sku_list provider-agnostic
---
 llama_stack/models/llama/sku_list.py        | 33 ++++++++++
 tests/unit/models/test_sku_resolve_alias.py | 68 +++++++++++++++++++++
 2 files changed, 101 insertions(+)
 create mode 100644 tests/unit/models/test_sku_resolve_alias.py

diff --git a/llama_stack/models/llama/sku_list.py b/llama_stack/models/llama/sku_list.py
index 271cec63f..7b3b18f85 100644
--- a/llama_stack/models/llama/sku_list.py
+++ b/llama_stack/models/llama/sku_list.py
@@ -19,9 +19,42 @@ LLAMA3_VOCAB_SIZE = 128256
 
 
 def resolve_model(descriptor: str) -> Model | None:
+    """Return the canonical `Model` that matches *descriptor*.
+
+    The helper originally accepted the model *descriptor* (e.g. "Llama-4-Scout-17B-16E-Instruct")
+    or the HuggingFace repository path (e.g. "meta-llama/Llama-4-Scout-17B-16E-Instruct").
+
+    Review feedback (see PR #2796) highlighted that callers - especially provider
+    adaptors - were passing provider-qualified aliases such as
+
+        "together/meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+    Having provider-specific logic here is undesirable. Instead of hard-coding
+    aliases in *this* file, we normalise the incoming descriptor by stripping a
+    leading provider prefix of the form "<provider>/" (e.g. "together/",
+    "groq/", …) *once* and then retry the lookup. This keeps sku_list
+    provider-agnostic while still resolving all legitimate aliases that
+    individual providers register in their own modules.
+    """
+
+    # Direct match against descriptor or HF repo.
     for m in all_registered_models():
         if descriptor in (m.descriptor(), m.huggingface_repo):
             return m
+
+    # Handle provider-prefixed aliases - strip provider prefix ("together/", "groq/", etc.) if present.
+    if "/" in descriptor:
+        # Many provider aliases look like "<provider>/<repo_path>"; we only need
+        # the repo_path (everything after the first slash) for a successful
+        # lookup. Splitting just once avoids over-stripping repo paths that
+        # legitimately contain more than one component (e.g. "meta-llama/…").
+        _, remainder = descriptor.split("/", 1)
+        # Recursively attempt to resolve the stripped descriptor to avoid code
+        # duplication. Recursion terminates because each call drops at least
+        # one leading path component before retrying the lookup.
+        if remainder != descriptor:  # guard against infinite recursion
+            return resolve_model(remainder)
+
     return None
 
 
diff --git a/tests/unit/models/test_sku_resolve_alias.py b/tests/unit/models/test_sku_resolve_alias.py
new file mode 100644
index 000000000..f2cfa2bee
--- /dev/null
+++ b/tests/unit/models/test_sku_resolve_alias.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.models.llama.sku_list import resolve_model
+
+
+def test_resolve_by_descriptor():
+    """Test normal resolution by model descriptor."""
+    model = resolve_model("Llama-4-Scout-17B-16E-Instruct")
+    assert model is not None
+    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"
+
+
+def test_resolve_by_huggingface_repo():
+    """Test normal resolution by HuggingFace repo path."""
+    model = resolve_model("meta-llama/Llama-4-Scout-17B-16E-Instruct")
+    assert model is not None
+    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"
+
+
+def test_together_alias_resolves():
+    """Test that Together-prefixed alias resolves via generic prefix stripping."""
+    alias = "together/meta-llama/Llama-4-Scout-17B-16E-Instruct"
+    model = resolve_model(alias)
+    assert model is not None, f"Model should resolve for alias {alias}"
+    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"
+
+
+def test_groq_alias_resolves():
+    """Test that Groq-prefixed alias resolves via generic prefix stripping."""
+    alias = "groq/meta-llama/Llama-4-Scout-17B-16E-Instruct"
+    model = resolve_model(alias)
+    assert model is not None, f"Model should resolve for alias {alias}"
+    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"
+
+
+def test_unknown_model_returns_none():
+    """Test that unknown model descriptors return None."""
+    model = resolve_model("nonexistent-model")
+    assert model is None
+
+
+def test_unknown_provider_prefix_returns_none():
+    """Test that unknown provider prefix with unknown model returns None."""
+    model = resolve_model("unknown-provider/nonexistent-model")
+    assert model is None
+
+
+def test_empty_string_returns_none():
+    """Test that empty string returns None."""
+    model = resolve_model("")
+    assert model is None
+
+
+def test_slash_only_returns_none():
+    """Test that just a slash returns None."""
+    model = resolve_model("/")
+    assert model is None
+
+
+def test_multiple_slashes_handled():
+    """Test that paths with multiple slashes are handled correctly."""
+    # This should strip "provider/" and try "path/to/model"
+    model = resolve_model("provider/path/to/model")
+    assert model is None  # Should be None since "path/to/model" doesn't exist
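For context, a minimal usage sketch of the behaviour this patch introduces (illustrative only, not part of the diff; it assumes "Llama-4-Scout-17B-16E-Instruct" is registered in sku_list, as the tests above do):

    from llama_stack.models.llama.sku_list import resolve_model

    # A provider-qualified alias: the leading "together/" prefix is stripped once,
    # then the remaining HuggingFace repo path is resolved through the normal lookup.
    model = resolve_model("together/meta-llama/Llama-4-Scout-17B-16E-Instruct")
    assert model is not None
    assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"

    # Unknown descriptors still return None, with or without a provider prefix.
    assert resolve_model("unknown-provider/nonexistent-model") is None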