mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
refactor: make sku_list resolve provider aliases generically
- Replace Together-specific import logic with generic provider prefix stripping - Addresses PR #2796 review feedback about provider-specific code in core modules - Add comprehensive test suite covering normal cases, provider prefixes, and edge cases - Maintains backward compatibility while keeping sku_list provider-agnostic
This commit is contained in:
parent
427136bb63
commit
3d43e143d2
2 changed files with 101 additions and 0 deletions
|
@ -19,9 +19,42 @@ LLAMA3_VOCAB_SIZE = 128256
|
|||
|
||||
|
||||
def resolve_model(descriptor: str) -> Model | None:
|
||||
"""Return the canonical `Model` that matches *descriptor*.
|
||||
|
||||
The helper originally accepted the model *descriptor* (e.g. "Llama-4-Scout-17B-16E-Instruct")
|
||||
or the HuggingFace repository path (e.g. "meta-llama/Llama-4-Scout-17B-16E-Instruct").
|
||||
|
||||
Review feedback (see PR #2796) highlighted that callers - especially provider
|
||||
adaptors - were passing provider-qualified aliases such as
|
||||
|
||||
"together/meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
Having provider-specific logic here is undesirable. Instead of hard-coding
|
||||
aliases in *this* file we normalise the incoming descriptor by stripping a
|
||||
leading provider prefix of the form "<provider>/" (e.g. "together/",
|
||||
"groq/", …) *once* and then retry the lookup. This keeps sku_list
|
||||
provider-agnostic while still resolving all legitimate aliases that
|
||||
individual providers register in their own modules.
|
||||
"""
|
||||
|
||||
# Direct match against descriptor or HF repo.
|
||||
for m in all_registered_models():
|
||||
if descriptor in (m.descriptor(), m.huggingface_repo):
|
||||
return m
|
||||
|
||||
# Handle provider-prefixed aliases - strip provider prefix ("together/", "groq/", etc.) if present.
|
||||
if "/" in descriptor:
|
||||
# Many provider aliases look like "<provider>/<repo_path>"; we only need
|
||||
# the repo_path (everything after the first slash) for a successful
|
||||
# lookup. Splitting just once avoids over-stripping repo paths that
|
||||
# legitimately contain more than one component (e.g. "meta-llama/…").
|
||||
_, remainder = descriptor.split("/", 1)
|
||||
# Recursively attempt to resolve the stripped descriptor to avoid code
|
||||
# duplication. The depth here is at most 1 because the second call will
|
||||
# hit the fast path above.
|
||||
if remainder != descriptor: # guard against infinite recursion
|
||||
return resolve_model(remainder)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
|
68
tests/unit/models/test_sku_resolve_alias.py
Normal file
68
tests/unit/models/test_sku_resolve_alias.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
|
||||
def test_resolve_by_descriptor():
|
||||
"""Test normal resolution by model descriptor."""
|
||||
model = resolve_model("Llama-4-Scout-17B-16E-Instruct")
|
||||
assert model is not None
|
||||
assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
|
||||
def test_resolve_by_huggingface_repo():
|
||||
"""Test normal resolution by HuggingFace repo path."""
|
||||
model = resolve_model("meta-llama/Llama-4-Scout-17B-16E-Instruct")
|
||||
assert model is not None
|
||||
assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
|
||||
def test_together_alias_resolves():
|
||||
"""Test that Together-prefixed alias resolves via generic prefix stripping."""
|
||||
alias = "together/meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
model = resolve_model(alias)
|
||||
assert model is not None, f"Model should resolve for alias {alias}"
|
||||
assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
|
||||
def test_groq_alias_resolves():
|
||||
"""Test that Groq-prefixed alias resolves via generic prefix stripping."""
|
||||
alias = "groq/meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
model = resolve_model(alias)
|
||||
assert model is not None, f"Model should resolve for alias {alias}"
|
||||
assert model.core_model_id.value == "Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
|
||||
def test_unknown_model_returns_none():
|
||||
"""Test that unknown model descriptors return None."""
|
||||
model = resolve_model("nonexistent-model")
|
||||
assert model is None
|
||||
|
||||
|
||||
def test_unknown_provider_prefix_returns_none():
|
||||
"""Test that unknown provider prefix with unknown model returns None."""
|
||||
model = resolve_model("unknown-provider/nonexistent-model")
|
||||
assert model is None
|
||||
|
||||
|
||||
def test_empty_string_returns_none():
|
||||
"""Test that empty string returns None."""
|
||||
model = resolve_model("")
|
||||
assert model is None
|
||||
|
||||
|
||||
def test_slash_only_returns_none():
|
||||
"""Test that just a slash returns None."""
|
||||
model = resolve_model("/")
|
||||
assert model is None
|
||||
|
||||
|
||||
def test_multiple_slashes_handled():
|
||||
"""Test that paths with multiple slashes are handled correctly."""
|
||||
# This should strip "provider/" and try "path/to/model"
|
||||
model = resolve_model("provider/path/to/model")
|
||||
assert model is None # Should be None since "path/to/model" doesn't exist
|
Loading…
Add table
Add a link
Reference in a new issue