diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index ead1331f3..2a565c93c 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2 RoutingKey = str | list[str] +class RegistryEntrySource(StrEnum): + default = "default" + provider = "provider" + + class User(BaseModel): principal: str # further attributes that may be used for access control decisions @@ -50,6 +55,7 @@ class ResourceWithOwner(Resource): resource. This can be used to constrain access to the resource.""" owner: User | None = None + source: RegistryEntrySource = RegistryEntrySource.default # Use the extended Resource for all routable objects diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 2f6ac90bb..421e4162b 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -206,7 +206,6 @@ class CommonRoutingTableImpl(RoutingTable): if obj.type == ResourceType.model.value: await self.dist_registry.register(registered_obj) return registered_obj - else: await self.dist_registry.register(obj) return obj diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py index 1454bf45f..7bbc1697f 100644 --- a/llama_stack/distribution/routing_tables/models.py +++ b/llama_stack/distribution/routing_tables/models.py @@ -11,6 +11,7 @@ from typing import Any from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel from llama_stack.distribution.datatypes import ( ModelWithOwner, + RegistryEntrySource, ) from llama_stack.log import get_logger @@ -65,7 +66,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): if models is None: continue - await self.update_registered_llm_models(provider_id, models) + await self.update_registered_models(provider_id, models) await asyncio.sleep(self.model_refresh_interval_seconds) @@ -131,6 +132,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id=provider_id, metadata=metadata, model_type=model_type, + source=RegistryEntrySource.default, ) registered_model = await self.register_object(model) return registered_model @@ -141,7 +143,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): raise ValueError(f"Model {model_id} not found") await self.unregister_object(existing_model) - async def update_registered_llm_models( + async def update_registered_models( self, provider_id: str, models: list[Model], @@ -152,18 +154,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): # from run.yaml) that we need to keep track of model_ids = {} for model in existing_models: - # we leave embeddings models alone because often we don't get metadata - # (embedding dimension, etc.) from the provider - if model.provider_id == provider_id and model.model_type == ModelType.llm: + if model.provider_id != provider_id: + continue + if model.source == RegistryEntrySource.default: model_ids[model.provider_resource_id] = model.identifier - logger.debug(f"unregistering model {model.identifier}") - await self.unregister_object(model) + continue + + logger.debug(f"unregistering model {model.identifier}") + await self.unregister_object(model) for model in models: - if model.model_type != ModelType.llm: - continue if model.provider_resource_id in model_ids: - model.identifier = model_ids[model.provider_resource_id] + # avoid overwriting a non-provider-registered model entry + continue logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})") await self.register_object( @@ -173,5 +176,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id=provider_id, metadata=model.metadata, model_type=model.model_type, + source=RegistryEntrySource.provider, ) ) diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 0beecd2c4..95b99c65e 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -20,7 +20,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate, ModelType from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, ) @@ -41,6 +41,8 @@ class SentenceTransformersInferenceImpl( InferenceProvider, ModelsProtocolPrivate, ): + __provider_id__: str + def __init__(self, config: SentenceTransformersInferenceConfig) -> None: self.config = config @@ -54,8 +56,17 @@ class SentenceTransformersInferenceImpl( return False async def list_models(self) -> list[Model] | None: - # TODO: add all-mini-lm models - return None + return [ + Model( + identifier="all-MiniLM-L6-v2", + provider_resource_id="all-MiniLM-L6-v2", + provider_id=self.__provider_id__, + metadata={ + "embedding_dimension": 384, + }, + model_type=ModelType.embedding, + ), + ] async def register_model(self, model: Model) -> Model: return model diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index e9a10d0a8..ee0049b2e 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -126,10 +126,44 @@ class OllamaInferenceAdapter( async def list_models(self) -> list[Model] | None: provider_id = self.__provider_id__ response = await self.client.list() - models = [] + + # always add the two embedding models which can be pulled on demand + models = [ + Model( + identifier="all-minilm:l6-v2", + provider_resource_id="all-minilm:l6-v2", + provider_id=provider_id, + metadata={ + "embedding_dimension": 384, + "context_length": 512, + }, + model_type=ModelType.embedding, + ), + # add all-minilm alias + Model( + identifier="all-minilm", + provider_resource_id="all-minilm:l6-v2", + provider_id=provider_id, + metadata={ + "embedding_dimension": 384, + "context_length": 512, + }, + model_type=ModelType.embedding, + ), + Model( + identifier="nomic-embed-text", + provider_resource_id="nomic-embed-text", + provider_id=provider_id, + metadata={ + "embedding_dimension": 768, + "context_length": 8192, + }, + model_type=ModelType.embedding, + ), + ] for m in response.models: - model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm - if model_type == ModelType.embedding: + # kill embedding models since we don't know dimensions for them + if m.details.family in ["bert"]: continue models.append( Model( @@ -137,7 +171,7 @@ class OllamaInferenceAdapter( provider_resource_id=m.model, provider_id=provider_id, metadata={}, - model_type=model_type, + model_type=ModelType.llm, ) ) return models diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 12b05ebff..fd1a7462f 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -15,6 +15,7 @@ from llama_stack.apis.models import Model, ModelType from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.distribution.datatypes import RegistryEntrySource from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable @@ -45,6 +46,30 @@ class InferenceImpl(Impl): async def unregister_model(self, model_id: str): return model_id + async def should_refresh_models(self): + return False + + async def list_models(self): + return [ + Model( + identifier="provider-model-1", + provider_resource_id="provider-model-1", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + Model( + identifier="provider-model-2", + provider_resource_id="provider-model-2", + provider_id="test_provider", + metadata={"embedding_dimension": 512}, + model_type=ModelType.embedding, + ), + ] + + async def shutdown(self): + pass + class SafetyImpl(Impl): def __init__(self): @@ -378,3 +403,170 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry): raise AssertionError("Should have raised ValueError for non-existent model") except ValueError as e: assert "not found" in str(e) + + +async def test_models_source_tracking_default(cached_disk_dist_registry): + """Test that models registered via register_model get default source.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register model via register_model (should get default source) + await table.register_model(model_id="user-model", provider_id="test_provider") + + models = await table.list_models() + assert len(models.data) == 1 + model = models.data[0] + assert model.source == RegistryEntrySource.default + assert model.identifier == "test_provider/user-model" + + # Cleanup + await table.shutdown() + + +async def test_models_source_tracking_provider(cached_disk_dist_registry): + """Test that models registered via update_registered_models get provider source.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Simulate provider refresh by calling update_registered_models + provider_models = [ + Model( + identifier="provider-model-1", + provider_resource_id="provider-model-1", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + Model( + identifier="provider-model-2", + provider_resource_id="provider-model-2", + provider_id="test_provider", + metadata={"embedding_dimension": 512}, + model_type=ModelType.embedding, + ), + ] + await table.update_registered_models("test_provider", provider_models) + + models = await table.list_models() + assert len(models.data) == 2 + + # All models should have provider source + for model in models.data: + assert model.source == RegistryEntrySource.provider + assert model.provider_id == "test_provider" + + # Cleanup + await table.shutdown() + + +async def test_models_source_interaction_preserves_default(cached_disk_dist_registry): + """Test that provider refresh preserves user-registered models with default source.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # First register a user model with same provider_resource_id as provider will later provide + await table.register_model( + model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider" + ) + + # Verify user model is registered with default source + models = await table.list_models() + assert len(models.data) == 1 + user_model = models.data[0] + assert user_model.source == RegistryEntrySource.default + assert user_model.identifier == "my-custom-alias" + assert user_model.provider_resource_id == "provider-model-1" + + # Now simulate provider refresh + provider_models = [ + Model( + identifier="provider-model-1", + provider_resource_id="provider-model-1", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + Model( + identifier="different-model", + provider_resource_id="different-model", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + ] + await table.update_registered_models("test_provider", provider_models) + + # Verify user model with alias is preserved, but provider added new model + models = await table.list_models() + assert len(models.data) == 2 + + # Find the user model and provider model + user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None) + provider_model = next((m for m in models.data if m.identifier == "different-model"), None) + + assert user_model is not None + assert user_model.source == RegistryEntrySource.default + assert user_model.provider_resource_id == "provider-model-1" + + assert provider_model is not None + assert provider_model.source == RegistryEntrySource.provider + assert provider_model.provider_resource_id == "different-model" + + # Cleanup + await table.shutdown() + + +async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry): + """Test that provider refresh removes old provider models but keeps default ones.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register a user model + await table.register_model(model_id="user-model", provider_id="test_provider") + + # Add some provider models + provider_models_v1 = [ + Model( + identifier="provider-model-old", + provider_resource_id="provider-model-old", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + ] + await table.update_registered_models("test_provider", provider_models_v1) + + # Verify we have both user and provider models + models = await table.list_models() + assert len(models.data) == 2 + + # Now update with new provider models (should remove old provider models) + provider_models_v2 = [ + Model( + identifier="provider-model-new", + provider_resource_id="provider-model-new", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + ] + await table.update_registered_models("test_provider", provider_models_v2) + + # Should have user model + new provider model, old provider model gone + models = await table.list_models() + assert len(models.data) == 2 + + identifiers = {m.identifier for m in models.data} + assert "test_provider/user-model" in identifiers # User model preserved + assert "provider-model-new" in identifiers # New provider model (uses provider's identifier) + assert "provider-model-old" not in identifiers # Old provider model removed + + # Verify sources are correct + user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None) + provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None) + + assert user_model.source == RegistryEntrySource.default + assert provider_model.source == RegistryEntrySource.provider + + # Cleanup + await table.shutdown()