diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index 930cf2646..1f963324b 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -40,6 +40,7 @@ RoutingKey = str | list[str] class RegistryEntrySource(StrEnum): via_register_api = "via_register_api" listed_from_provider = "listed_from_provider" + from_config = "from_config" class User(BaseModel): diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index ca2f3af42..48a6250a5 100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -191,7 +191,14 @@ class CommonRoutingTableImpl(RoutingTable): if not is_action_allowed(self.policy, "delete", obj, user): raise AccessDeniedError("delete", obj, user) await self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + + # Only try to unregister from provider if the provider still exists + if obj.provider_id in self.impls_by_provider_id: + await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + else: + logger.debug( + f"Provider {obj.provider_id} no longer exists, skipping provider unregistration for {obj.identifier}" + ) async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: # if provider_id is not specified, pick an arbitrary one from existing entries @@ -254,6 +261,13 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> if model is not None: return model + # Check from_config models if this is a ModelsRoutingTable + if hasattr(routing_table, "_generate_from_config_models"): + from_config_models = routing_table._generate_from_config_models() + for from_config_model in from_config_models: + if from_config_model.identifier == model_id: + return from_config_model + logger.warning( f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to " "searching in all providers. This is only for backwards compatibility and will stop working " diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 641c73c16..9363b6ab5 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -12,6 +12,7 @@ from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.core.datatypes import ( ModelWithOwner, RegistryEntrySource, + StackRunConfig, ) from llama_stack.log import get_logger @@ -22,6 +23,7 @@ logger = get_logger(name=__name__, category="core::routing_tables") class ModelsRoutingTable(CommonRoutingTableImpl, Models): listed_providers: set[str] = set() + current_run_config: "StackRunConfig | None" = None async def refresh(self) -> None: for provider_id, provider in self.impls_by_provider_id.items(): @@ -74,6 +76,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id: str | None = None, metadata: dict[str, Any] | None = None, model_type: ModelType | None = None, + source: RegistryEntrySource = RegistryEntrySource.via_register_api, ) -> Model: if provider_id is None: # If provider_id not specified, use the only provider if it supports this model @@ -106,7 +109,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id=provider_id, metadata=metadata, model_type=model_type, - source=RegistryEntrySource.via_register_api, + source=source, ) registered_model = await self.register_object(model) return registered_model @@ -117,6 +120,61 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): raise ModelNotFoundError(model_id) await self.unregister_object(existing_model) + async def cleanup_disabled_provider_models(self) -> None: + """Remove models from providers that are no longer enabled in the current run config.""" + if not self.current_run_config: + return + + # Get enabled provider IDs from the current run config + enabled_provider_ids = set() + for _api, providers in self.current_run_config.providers.items(): + for provider in providers: + if provider.provider_id and provider.provider_id != "__disabled__": + enabled_provider_ids.add(provider.provider_id) + + # Get all existing models + existing_models = await self.get_all_with_type("model") + + # Find models from disabled providers (excluding user-registered models) + models_to_remove = [] + for model in existing_models: + if model.provider_id not in enabled_provider_ids and model.source != RegistryEntrySource.via_register_api: + models_to_remove.append(model) + + # Remove the models + for model in models_to_remove: + logger.info(f"Removing model {model.identifier} from disabled provider {model.provider_id}") + await self.unregister_object(model) + + async def register_from_config_models(self) -> None: + """Register from_config models from the current run configuration.""" + if not self.current_run_config: + return + + # Register new from_config models (old ones automatically disappear since they're not persisted) + for model_input in self.current_run_config.models: + # Skip models with disabled providers + if not model_input.provider_id or model_input.provider_id == "__disabled__": + continue + + # Generate identifier + if model_input.model_id != (model_input.provider_model_id or model_input.model_id): + identifier = model_input.model_id + else: + identifier = f"{model_input.provider_id}/{model_input.provider_model_id or model_input.model_id}" + + model = ModelWithOwner( + identifier=identifier, + provider_resource_id=model_input.provider_model_id or model_input.model_id, + provider_id=model_input.provider_id, + metadata=model_input.metadata, + model_type=model_input.model_type or ModelType.llm, + source=RegistryEntrySource.from_config, + ) + + # Register the model (will be cached in memory but not persisted to disk) + await self.dist_registry.register(model) + async def update_registered_models( self, provider_id: str, diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index 3e14328a3..1a46be3b6 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -99,6 +99,15 @@ TEST_RECORDING_CONTEXT = None async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): + # Set the run config on the models routing table for generating from_config models + if Api.models in impls: + models_impl = impls[Api.models] + models_impl.current_run_config = run_config + # Clean up models from disabled providers + await models_impl.cleanup_disabled_provider_models() + # Register from_config models + await models_impl.register_from_config_models() + for rsrc, api, register_method, list_method in RESOURCES: objects = getattr(run_config, rsrc) if api not in impls: @@ -116,7 +125,16 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): # we want to maintain the type information in arguments to method. # instead of method(**obj.model_dump()), which may convert a typed attr to a dict, # we use model_dump() to find all the attrs and then getattr to get the still typed value. - await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()}) + kwargs = {k: getattr(obj, k) for k in obj.model_dump().keys()} + + # Skip registering from_config models since they are registered through the routing table's set_run_config + if rsrc == "models": + logger.debug( + f"Skipping registration of from_config model {obj.model_id} - will be registered through routing table" + ) + continue + + await method(**kwargs) method = getattr(impls[api], list_method) response = await method() diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 624dbd176..a2f3dd230 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -100,6 +100,11 @@ class DiskDistributionRegistry(DistributionRegistry): if existing_obj and existing_obj.provider_id == obj.provider_id: return False + # Skip persistence for from_config objects - they should only exist in memory + if hasattr(obj, "source") and obj.source == "from_config": + logger.debug(f"Skipping persistence for from_config object {obj.type}:{obj.identifier}") + return True + await self.kvstore.set( KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), obj.model_dump_json(), diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 54a9dd72e..d4919fa70 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -18,7 +18,7 @@ from llama_stack.apis.models import Model, ModelType from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup from llama_stack.apis.vector_dbs import VectorDB -from llama_stack.core.datatypes import RegistryEntrySource +from llama_stack.core.datatypes import RegistryEntrySource, StackRunConfig from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable from llama_stack.core.routing_tables.models import ModelsRoutingTable @@ -538,114 +538,271 @@ async def test_models_source_tracking_provider(cached_disk_dist_registry): await table.shutdown() -async def test_models_source_interaction_preserves_default(cached_disk_dist_registry): - """Test that provider refresh preserves user-registered models with default source.""" - table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) +async def test_models_dynamic_from_config_generation(cached_disk_dist_registry): + """Test that from_config models are generated dynamically from run_config.""" + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) await table.initialize() - # First register a user model with same provider_resource_id as provider will later provide - await table.register_model( - model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider" + # Test that no from_config models are registered when no run_config + all_models = await table.get_all_with_type("model") + from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config] + assert len(from_config_models) == 0 + + # Create a run config with from_config models + run_config = StackRunConfig( + image_name="test", + providers={}, + models=[ + { + "model_id": "from_config_model_1", + "provider_id": "test_provider", + "model_type": "llm", + "provider_model_id": "gpt-3.5-turbo", + }, + { + "model_id": "from_config_model_2", + "provider_id": "test_provider", + "model_type": "llm", + "provider_model_id": "gpt-4", + }, + ], ) - # Verify user model is registered with default source - models = await table.list_models() - assert len(models.data) == 1 - user_model = models.data[0] - assert user_model.source == RegistryEntrySource.via_register_api - assert user_model.identifier == "my-custom-alias" - assert user_model.provider_resource_id == "provider-model-1" + # Set the run config + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() - # Now simulate provider refresh - provider_models = [ - Model( - identifier="provider-model-1", - provider_resource_id="provider-model-1", - provider_id="test_provider", - metadata={}, - model_type=ModelType.llm, - ), - Model( - identifier="different-model", - provider_resource_id="different-model", - provider_id="test_provider", - metadata={}, - model_type=ModelType.llm, - ), - ] - await table.update_registered_models("test_provider", provider_models) + # Test that from_config models are registered in the registry + all_models = await table.get_all_with_type("model") + from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config] + assert len(from_config_models) == 2 - # Verify user model with alias is preserved, but provider added new model - models = await table.list_models() - assert len(models.data) == 2 + model_identifiers = {m.identifier for m in from_config_models} + assert "from_config_model_1" in model_identifiers + assert "from_config_model_2" in model_identifiers - # Find the user model and provider model - user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None) - provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None) - - assert user_model is not None - assert user_model.source == RegistryEntrySource.via_register_api - assert user_model.provider_resource_id == "provider-model-1" - - assert provider_model is not None - assert provider_model.source == RegistryEntrySource.listed_from_provider - assert provider_model.provider_resource_id == "different-model" + # Test that from_config models have correct attributes + model_1 = next(m for m in from_config_models if m.identifier == "from_config_model_1") + assert model_1.provider_id == "test_provider" + assert model_1.provider_resource_id == "gpt-3.5-turbo" + assert model_1.model_type == ModelType.llm + assert model_1.source == RegistryEntrySource.from_config # Cleanup await table.shutdown() -async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry): - """Test that provider refresh removes old provider models but keeps default ones.""" - table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) +async def test_models_dynamic_from_config_lookup(cached_disk_dist_registry): + """Test that from_config models can be looked up individually.""" + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) await table.initialize() - # Register a user model - await table.register_model(model_id="user-model", provider_id="test_provider") + # Create a run config with from_config models + run_config = StackRunConfig( + image_name="test", + providers={}, + models=[ + { + "model_id": "lookup_test_model", + "provider_id": "test_provider", + "model_type": "llm", + "provider_model_id": "gpt-3.5-turbo", + } + ], + ) - # Add some provider models - provider_models_v1 = [ - Model( - identifier="provider-model-old", - provider_resource_id="provider-model-old", - provider_id="test_provider", - metadata={}, - model_type=ModelType.llm, - ), - ] - await table.update_registered_models("test_provider", provider_models_v1) + # Set the run config + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() - # Verify we have both user and provider models + # Test that we can get the from_config model individually + model = await table.get_model("lookup_test_model") + assert model is not None + assert model.identifier == "lookup_test_model" + assert model.provider_id == "test_provider" + assert model.provider_resource_id == "gpt-3.5-turbo" + assert model.source == RegistryEntrySource.from_config + + # Cleanup + await table.shutdown() + + +async def test_models_dynamic_from_config_mixed_with_persistent(cached_disk_dist_registry): + """Test that from_config models work alongside persistent models.""" + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) + await table.initialize() + + # Create a run config with from_config models + run_config = StackRunConfig( + image_name="test", + providers={}, + models=[ + { + "model_id": "from_config_model", + "provider_id": "test_provider", + "model_type": "llm", + "provider_model_id": "gpt-3.5-turbo", + } + ], + ) + + # Set the run config + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() + + # Test that from_config models are included models = await table.list_models() - assert len(models.data) == 2 + from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config] - # Now update with new provider models (should remove old provider models) - provider_models_v2 = [ - Model( - identifier="provider-model-new", - provider_resource_id="provider-model-new", - provider_id="test_provider", - metadata={}, - model_type=ModelType.llm, - ), - ] - await table.update_registered_models("test_provider", provider_models_v2) + assert len(from_config_models) == 1 + assert from_config_models[0].identifier == "from_config_model" - # Should have user model + new provider model, old provider model gone + # Test that we can get the from_config model individually + from_config_model = await table.get_model("from_config_model") + assert from_config_model is not None + assert from_config_model.source == RegistryEntrySource.from_config + + # Cleanup + await table.shutdown() + + +async def test_models_dynamic_from_config_disabled_providers(cached_disk_dist_registry): + """Test that from_config models with disabled providers are skipped.""" + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) + await table.initialize() + + # Create a run config with disabled provider models + run_config = StackRunConfig( + image_name="test", + providers={}, + models=[ + { + "model_id": "enabled_model", + "provider_id": "test_provider", + "model_type": "llm", + "provider_model_id": "gpt-3.5-turbo", + }, + { + "model_id": "disabled_model", + "provider_id": "__disabled__", + "model_type": "llm", + "provider_model_id": "gpt-4", + }, + ], + ) + + # Set the run config + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() + + # Test that only enabled models are included + all_models = await table.get_all_with_type("model") + from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config] + assert len(from_config_models) == 1 + assert from_config_models[0].identifier == "enabled_model" + + # Cleanup + await table.shutdown() + + +async def test_models_dynamic_from_config_no_run_config(cached_disk_dist_registry): + """Test that from_config models work when no run_config is set.""" + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) + await table.initialize() + + # Test that list_models works without run_config models = await table.list_models() - assert len(models.data) == 2 + from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config] + assert len(from_config_models) == 0 # No from_config models when no run_config - identifiers = {m.identifier for m in models.data} - assert "test_provider/user-model" in identifiers # User model preserved - assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier) - assert "test_provider/provider-model-old" not in identifiers # Old provider model removed + # Cleanup + await table.shutdown() - # Verify sources are correct - user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None) - provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None) - assert user_model.source == RegistryEntrySource.via_register_api - assert provider_model.source == RegistryEntrySource.listed_from_provider +async def test_models_filter_persistent_models_from_removed_providers(cached_disk_dist_registry): + """Test that models from removed providers are filtered out from persistent models.""" + from llama_stack.apis.models import ModelType + from llama_stack.core.datatypes import ModelWithOwner, Provider, RegistryEntrySource, StackRunConfig + from llama_stack.core.routing_tables.models import ModelsRoutingTable + + # Create a routing table + table = ModelsRoutingTable({}, cached_disk_dist_registry, {}) + await table.initialize() + + # Create some mock persistent models + model1 = ModelWithOwner( + identifier="test_provider_1/model1", + provider_resource_id="model1", + provider_id="test_provider_1", + metadata={}, + model_type=ModelType.llm, + source=RegistryEntrySource.listed_from_provider, + ) + model2 = ModelWithOwner( + identifier="test_provider_2/model2", + provider_resource_id="model2", + provider_id="test_provider_2", + metadata={}, + model_type=ModelType.llm, + source=RegistryEntrySource.listed_from_provider, + ) + user_model = ModelWithOwner( + identifier="user_model", + provider_resource_id="user_model", + provider_id="test_provider_1", + metadata={}, + model_type=ModelType.llm, + source=RegistryEntrySource.via_register_api, + ) + + # Create a run config that only includes test_provider_1 (test_provider_2 is removed) + run_config = StackRunConfig( + image_name="test", + providers={ + "inference": [ + Provider(provider_id="test_provider_1", provider_type="openai", config={"api_key": "test_key"}), + # test_provider_2 is removed from run.yaml + ] + }, + models=[], + ) + + # Set the run config + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() + + # Test the cleanup logic directly + # First, manually add models to the registry to simulate existing models + await table.dist_registry.register(model1) + await table.dist_registry.register(model2) + await table.dist_registry.register(user_model) + + # Now set the run config which should trigger cleanup + table.current_run_config = run_config + await table.cleanup_disabled_provider_models() + await table.register_from_config_models() + + # Get the list of models after cleanup + response = await table.list_models() + model_identifiers = {m.identifier for m in response.data} + + # Should have user_model (user-registered) and model1 (from enabled provider), but not model2 (from disabled provider) + # model1 should be kept because test_provider_1 is in the run config (enabled) + # model2 should be removed because test_provider_2 is not in the run config (disabled) + # user_model should be kept because it's user-registered + assert "user_model" in model_identifiers + assert "test_provider_1/model1" in model_identifiers + assert "test_provider_2/model2" not in model_identifiers + + # Test that user-registered models are always kept regardless of provider status + user_model_found = next((m for m in response.data if m.identifier == "user_model"), None) + assert user_model_found is not None + assert user_model_found.source == RegistryEntrySource.via_register_api # Cleanup await table.shutdown()