mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-25 21:57:45 +00:00
feat(registry): make the Stack query providers for model listing (#2862)
This flips #2823 and #2805 by making the Stack periodically query the providers for models rather than the providers going behind the back and calling "register" on to the registry themselves. This also adds support for model listing for all other providers via `ModelRegistryHelper`. Once this is done, we do not need to manually list or register models via `run.yaml` and it will remove both noise and annoyance (setting `INFERENCE_MODEL` environment variables, for example) from the new user experience. In addition, it adds a configuration variable `allowed_models` which can be used to optionally restrict the set of models exposed from a provider.
This commit is contained in:
parent
537dc693ee
commit
1463b79218
23 changed files with 429 additions and 147 deletions
|
@ -15,6 +15,7 @@ from llama_stack.apis.models import Model, ModelType
|
|||
from llama_stack.apis.shields.shields import Shield
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.distribution.datatypes import RegistryEntrySource
|
||||
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
||||
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
||||
|
@ -45,6 +46,30 @@ class InferenceImpl(Impl):
|
|||
async def unregister_model(self, model_id: str):
|
||||
return model_id
|
||||
|
||||
async def should_refresh_models(self):
|
||||
return False
|
||||
|
||||
async def list_models(self):
|
||||
return [
|
||||
Model(
|
||||
identifier="provider-model-1",
|
||||
provider_resource_id="provider-model-1",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
Model(
|
||||
identifier="provider-model-2",
|
||||
provider_resource_id="provider-model-2",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 512},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
|
||||
class SafetyImpl(Impl):
|
||||
def __init__(self):
|
||||
|
@ -378,3 +403,170 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
|
|||
raise AssertionError("Should have raised ValueError for non-existent model")
|
||||
except ValueError as e:
|
||||
assert "not found" in str(e)
|
||||
|
||||
|
||||
async def test_models_source_tracking_default(cached_disk_dist_registry):
|
||||
"""Test that models registered via register_model get default source."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register model via register_model (should get default source)
|
||||
await table.register_model(model_id="user-model", provider_id="test_provider")
|
||||
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
model = models.data[0]
|
||||
assert model.source == RegistryEntrySource.via_register_api
|
||||
assert model.identifier == "test_provider/user-model"
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_source_tracking_provider(cached_disk_dist_registry):
|
||||
"""Test that models registered via update_registered_models get provider source."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Simulate provider refresh by calling update_registered_models
|
||||
provider_models = [
|
||||
Model(
|
||||
identifier="provider-model-1",
|
||||
provider_resource_id="provider-model-1",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
Model(
|
||||
identifier="provider-model-2",
|
||||
provider_resource_id="provider-model-2",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 512},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models)
|
||||
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
# All models should have provider source
|
||||
for model in models.data:
|
||||
assert model.source == RegistryEntrySource.listed_from_provider
|
||||
assert model.provider_id == "test_provider"
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
|
||||
"""Test that provider refresh preserves user-registered models with default source."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# First register a user model with same provider_resource_id as provider will later provide
|
||||
await table.register_model(
|
||||
model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
|
||||
)
|
||||
|
||||
# Verify user model is registered with default source
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
user_model = models.data[0]
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert user_model.identifier == "my-custom-alias"
|
||||
assert user_model.provider_resource_id == "provider-model-1"
|
||||
|
||||
# Now simulate provider refresh
|
||||
provider_models = [
|
||||
Model(
|
||||
identifier="provider-model-1",
|
||||
provider_resource_id="provider-model-1",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
Model(
|
||||
identifier="different-model",
|
||||
provider_resource_id="different-model",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models)
|
||||
|
||||
# Verify user model with alias is preserved, but provider added new model
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
# Find the user model and provider model
|
||||
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "different-model"), None)
|
||||
|
||||
assert user_model is not None
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert user_model.provider_resource_id == "provider-model-1"
|
||||
|
||||
assert provider_model is not None
|
||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
||||
assert provider_model.provider_resource_id == "different-model"
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
|
||||
"""Test that provider refresh removes old provider models but keeps default ones."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a user model
|
||||
await table.register_model(model_id="user-model", provider_id="test_provider")
|
||||
|
||||
# Add some provider models
|
||||
provider_models_v1 = [
|
||||
Model(
|
||||
identifier="provider-model-old",
|
||||
provider_resource_id="provider-model-old",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models_v1)
|
||||
|
||||
# Verify we have both user and provider models
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
# Now update with new provider models (should remove old provider models)
|
||||
provider_models_v2 = [
|
||||
Model(
|
||||
identifier="provider-model-new",
|
||||
provider_resource_id="provider-model-new",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models_v2)
|
||||
|
||||
# Should have user model + new provider model, old provider model gone
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
identifiers = {m.identifier for m in models.data}
|
||||
assert "test_provider/user-model" in identifiers # User model preserved
|
||||
assert "provider-model-new" in identifiers # New provider model (uses provider's identifier)
|
||||
assert "provider-model-old" not in identifiers # Old provider model removed
|
||||
|
||||
# Verify sources are correct
|
||||
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None)
|
||||
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue