diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index faaeefd01..137dd5e13 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -40,6 +40,7 @@ RoutingKey = str | list[str] class RegistryEntrySource(StrEnum): via_register_api = "via_register_api" listed_from_provider = "listed_from_provider" + from_config = "from_config" class User(BaseModel): diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index ca2f3af42..48a6250a5 100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -191,7 +191,14 @@ class CommonRoutingTableImpl(RoutingTable): if not is_action_allowed(self.policy, "delete", obj, user): raise AccessDeniedError("delete", obj, user) await self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + + # Only try to unregister from provider if the provider still exists + if obj.provider_id in self.impls_by_provider_id: + await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + else: + logger.debug( + f"Provider {obj.provider_id} no longer exists, skipping provider unregistration for {obj.identifier}" + ) async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: # if provider_id is not specified, pick an arbitrary one from existing entries @@ -254,6 +261,13 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> if model is not None: return model + # Check from_config models if this is a ModelsRoutingTable + if hasattr(routing_table, "_generate_from_config_models"): + from_config_models = routing_table._generate_from_config_models() + for from_config_model in from_config_models: + if from_config_model.identifier == model_id: + return from_config_model + logger.warning( f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to " "searching in all providers. This is only for backwards compatibility and will stop working " diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index b6141efa9..54ff580c8 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -12,6 +12,7 @@ from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.core.datatypes import ( ModelWithOwner, RegistryEntrySource, + StackRunConfig, ) from llama_stack.log import get_logger @@ -22,6 +23,7 @@ logger = get_logger(name=__name__, category="core::routing_tables") class ModelsRoutingTable(CommonRoutingTableImpl, Models): listed_providers: set[str] = set() + current_run_config: "StackRunConfig | None" = None async def refresh(self) -> None: for provider_id, provider in self.impls_by_provider_id.items(): @@ -74,6 +76,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id: str | None = None, metadata: dict[str, Any] | None = None, model_type: ModelType | None = None, + source: RegistryEntrySource = RegistryEntrySource.via_register_api, ) -> Model: if provider_id is None: # If provider_id not specified, use the only provider if it supports this model @@ -106,7 +109,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id=provider_id, metadata=metadata, model_type=model_type, - source=RegistryEntrySource.via_register_api, + source=source, ) registered_model = await self.register_object(model) return registered_model @@ -117,6 +120,61 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): raise ModelNotFoundError(model_id) await self.unregister_object(existing_model) + async def cleanup_disabled_provider_models(self) -> None: + """Remove models from providers that are no longer enabled in the current run config.""" + if not self.current_run_config: + return + + # Get enabled provider IDs from the current run config + enabled_provider_ids = set() + for _api, providers in self.current_run_config.providers.items(): + for provider in providers: + if provider.provider_id and provider.provider_id != "__disabled__": + enabled_provider_ids.add(provider.provider_id) + + # Get all existing models + existing_models = await self.get_all_with_type("model") + + # Find models from disabled providers (excluding user-registered models) + models_to_remove = [] + for model in existing_models: + if model.provider_id not in enabled_provider_ids and model.source != RegistryEntrySource.via_register_api: + models_to_remove.append(model) + + # Remove the models + for model in models_to_remove: + logger.info(f"Removing model {model.identifier} from disabled provider {model.provider_id}") + await self.unregister_object(model) + + async def register_from_config_models(self) -> None: + """Register from_config models from the current run configuration.""" + if not self.current_run_config: + return + + # Register new from_config models (old ones automatically disappear since they're not persisted) + for model_input in self.current_run_config.models: + # Skip models with disabled providers + if not model_input.provider_id or model_input.provider_id == "__disabled__": + continue + + # Generate identifier + if model_input.model_id != (model_input.provider_model_id or model_input.model_id): + identifier = model_input.model_id + else: + identifier = f"{model_input.provider_id}/{model_input.provider_model_id or model_input.model_id}" + + model = ModelWithOwner( + identifier=identifier, + provider_resource_id=model_input.provider_model_id or model_input.model_id, + provider_id=model_input.provider_id, + metadata=model_input.metadata, + model_type=model_input.model_type or ModelType.llm, + source=RegistryEntrySource.from_config, + ) + + # Register the model (will be cached in memory but not persisted to disk) + await self.dist_registry.register(model) + async def update_registered_models( self, provider_id: str, diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index 7ab8d2c64..a13cd0211 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -101,6 +101,15 @@ TEST_RECORDING_CONTEXT = None async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): + # Set the run config on the models routing table for generating from_config models + if Api.models in impls: + models_impl = impls[Api.models] + models_impl.current_run_config = run_config + # Clean up models from disabled providers + await models_impl.cleanup_disabled_provider_models() + # Register from_config models + await models_impl.register_from_config_models() + for rsrc, api, register_method, list_method in RESOURCES: objects = getattr(run_config, rsrc) if api not in impls: @@ -118,7 +127,16 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): # we want to maintain the type information in arguments to method. # instead of method(**obj.model_dump()), which may convert a typed attr to a dict, # we use model_dump() to find all the attrs and then getattr to get the still typed value. - await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()}) + kwargs = {k: getattr(obj, k) for k in obj.model_dump().keys()} + + # Skip registering from_config models since they are registered through the routing table's set_run_config + if rsrc == "models": + logger.debug( + f"Skipping registration of from_config model {obj.model_id} - will be registered through routing table" + ) + continue + + await method(**kwargs) method = getattr(impls[api], list_method) response = await method() diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index a764d692a..d1243cf33 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -102,6 +102,11 @@ class DiskDistributionRegistry(DistributionRegistry): f"Object {existing_obj.type}:{existing_obj.identifier}'s {existing_obj.provider_id} provider is being replaced with {obj.provider_id}" ) + # Skip persistence for from_config objects - they should only exist in memory + if hasattr(obj, "source") and obj.source == "from_config": + logger.debug(f"Skipping persistence for from_config object {obj.type}:{obj.identifier}") + return True + await self.kvstore.set( KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), obj.model_dump_json(),