refactor: make sku_list resolve provider aliases generically

- Replace Together-specific import logic with generic provider prefix stripping
- Address PR #2796 review feedback about provider-specific code in core modules
- Add comprehensive test suite covering normal cases, provider prefixes, and edge cases
- Maintain backward compatibility while keeping sku_list provider-agnostic
skamenan7 2025-07-21 14:12:24 -04:00 committed by Sumanth Kamenani
parent 427136bb63
commit 3d43e143d2
2 changed files with 101 additions and 0 deletions

@@ -19,9 +19,42 @@ LLAMA3_VOCAB_SIZE = 128256
def resolve_model(descriptor: str) -> Model | None:
    """Return the canonical `Model` that matches *descriptor*.

    The helper originally accepted the model *descriptor* (e.g. "Llama-4-Scout-17B-16E-Instruct")
    or the HuggingFace repository path (e.g. "meta-llama/Llama-4-Scout-17B-16E-Instruct").

    Review feedback (see PR #2796) highlighted that callers - especially provider
    adaptors - were passing provider-qualified aliases such as
    "together/meta-llama/Llama-4-Scout-17B-16E-Instruct".

    Having provider-specific logic here is undesirable. Instead of hard-coding
    aliases in *this* file we normalise the incoming descriptor by stripping a
    leading provider prefix of the form "<provider>/" (e.g. "together/",
    "groq/") and then retrying the lookup on the remainder. This keeps sku_list
    provider-agnostic while still resolving all legitimate aliases that
    individual providers register in their own modules.
    """
    # Direct match against descriptor or HF repo.
    for m in all_registered_models():
        if descriptor in (m.descriptor(), m.huggingface_repo):
            return m

    # Handle provider-prefixed aliases: strip the provider prefix ("together/", "groq/", etc.) if present.
    if "/" in descriptor:
        # Many provider aliases look like "<provider>/<repo_path>"; we only need
        # the repo_path (everything after the first slash) for a successful
        # lookup. Passing maxsplit=1 removes only the first component, so repo
        # paths that legitimately contain more than one component
        # (e.g. "meta-llama/…") stay intact.
        _, remainder = descriptor.split("/", 1)
        # Recurse on the stripped descriptor to avoid duplicating the lookup.
        # A valid alias resolves on the second call via the direct match above;
        # an unknown descriptor loses one leading component per call, so the
        # depth is bounded by the number of slashes it contains.
        if remainder != descriptor:  # always true after a split on "/"; defensive only
            return resolve_model(remainder)

    return None
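
The lookup order described in the docstring can be tried in isolation with the minimal sketch below. It is not the code from this commit: `FakeModel`, `REGISTRY`, and `resolve_model_sketch` are hypothetical stand-ins for `Model`, `all_registered_models()`, and `resolve_model`, and the registry holds only the one model named in the docstring.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class FakeModel:
    """Hypothetical stand-in for the real `Model` class."""

    name: str
    huggingface_repo: str

    def descriptor(self) -> str:
        return self.name


# Hypothetical registry; the real helper iterates all_registered_models().
REGISTRY = [
    FakeModel(
        name="Llama-4-Scout-17B-16E-Instruct",
        huggingface_repo="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    ),
]


def resolve_model_sketch(descriptor: str) -> FakeModel | None:
    """Try a direct match, then retry with the leading "<prefix>/" removed."""
    for m in REGISTRY:
        if descriptor in (m.descriptor(), m.huggingface_repo):
            return m
    if "/" in descriptor:
        # maxsplit=1 drops only the first path component.
        _, remainder = descriptor.split("/", 1)
        return resolve_model_sketch(remainder)
    return None


if __name__ == "__main__":
    for alias in (
        "Llama-4-Scout-17B-16E-Instruct",
        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "together/meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "groq/unknown-model",
    ):
        print(alias, "->", resolve_model_sketch(alias))
```

In this sketch the first three aliases resolve to the same registry entry and the last one yields `None`. Because the retry is recursive, a descriptor that never matches is shortened one component at a time until no slash remains, so any leading path segments are accepted, not just a single provider name.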