fix(models): always prefix models with provider_id when registering

This commit is contained in:
Ashwin Bharambe 2025-10-15 21:31:15 -07:00
parent f205ab6f6c
commit d8be3111db
6 changed files with 13 additions and 73 deletions

View file

@ -117,42 +117,24 @@ def client_with_models(
text_model_id,
vision_model_id,
embedding_model_id,
embedding_dimension,
judge_model_id,
):
client = llama_stack_client
providers = [p for p in client.providers.list() if p.api == "inference"]
assert len(providers) > 0, "No inference providers found"
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
model_ids = {m.identifier for m in client.models.list()}
model_ids.update(m.provider_resource_id for m in client.models.list())
# TODO: fix this crap where we use the first provider randomly
# that cannot be right. I think the test should just specify the provider_id
if text_model_id and text_model_id not in model_ids:
client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
raise ValueError(f"text_model_id {text_model_id} not found")
if vision_model_id and vision_model_id not in model_ids:
client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
raise ValueError(f"vision_model_id {vision_model_id} not found")
if judge_model_id and judge_model_id not in model_ids:
client.models.register(model_id=judge_model_id, provider_id=inference_providers[0])
raise ValueError(f"judge_model_id {judge_model_id} not found")
if embedding_model_id and embedding_model_id not in model_ids:
# try to find a provider that supports embeddings, if sentence-transformers is not available
selected_provider = None
for p in providers:
if p.provider_type == "inline::sentence-transformers":
selected_provider = p
break
selected_provider = selected_provider or providers[0]
client.models.register(
model_id=embedding_model_id,
provider_id=selected_provider.provider_id,
model_type="embedding",
metadata={"embedding_dimension": embedding_dimension or 768},
)
raise ValueError(f"embedding_model_id {embedding_model_id} not found")
return client