This commit is contained in:
Ignas Baranauskas 2025-10-03 14:11:23 +02:00 committed by GitHub
commit c645c49924
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 342 additions and 89 deletions

View file

@ -191,7 +191,14 @@ class CommonRoutingTableImpl(RoutingTable):
if not is_action_allowed(self.policy, "delete", obj, user):
raise AccessDeniedError("delete", obj, user)
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
# Only try to unregister from provider if the provider still exists
if obj.provider_id in self.impls_by_provider_id:
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
else:
logger.debug(
f"Provider {obj.provider_id} no longer exists, skipping provider unregistration for {obj.identifier}"
)
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries
@ -254,6 +261,13 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
if model is not None:
return model
# Check from_config models if this is a ModelsRoutingTable
if hasattr(routing_table, "_generate_from_config_models"):
from_config_models = routing_table._generate_from_config_models()
for from_config_model in from_config_models:
if from_config_model.identifier == model_id:
return from_config_model
logger.warning(
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
"searching in all providers. This is only for backwards compatibility and will stop working "

View file

@ -12,6 +12,7 @@ from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.core.datatypes import (
ModelWithOwner,
RegistryEntrySource,
StackRunConfig,
)
from llama_stack.log import get_logger
@ -22,6 +23,7 @@ logger = get_logger(name=__name__, category="core::routing_tables")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
listed_providers: set[str] = set()
current_run_config: "StackRunConfig | None" = None
async def refresh(self) -> None:
for provider_id, provider in self.impls_by_provider_id.items():
@ -74,6 +76,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
source: RegistryEntrySource = RegistryEntrySource.via_register_api,
) -> Model:
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
@ -106,7 +109,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
source=RegistryEntrySource.via_register_api,
source=source,
)
registered_model = await self.register_object(model)
return registered_model
@ -117,6 +120,61 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
raise ModelNotFoundError(model_id)
await self.unregister_object(existing_model)
async def cleanup_disabled_provider_models(self) -> None:
    """Unregister models whose provider is not enabled in the current run config.

    Does nothing when ``self.current_run_config`` is unset. A provider counts
    as enabled when it appears under any API in ``current_run_config.providers``
    with a non-empty ``provider_id`` other than ``"__disabled__"``. Models whose
    ``source`` is ``RegistryEntrySource.via_register_api`` are always kept, even
    if their provider is not enabled.
    """
    run_config = self.current_run_config
    if not run_config:
        return

    # Collect the provider ids that are enabled, across every API section.
    active_provider_ids: set[str] = set()
    for provider_group in run_config.providers.values():
        active_provider_ids.update(
            entry.provider_id
            for entry in provider_group
            if entry.provider_id and entry.provider_id != "__disabled__"
        )

    registered_models = await self.get_all_with_type("model")

    # Build the full removal list up front so unregistering never interleaves
    # with the scan over the registry snapshot.
    stale_models = [
        candidate
        for candidate in registered_models
        if candidate.provider_id not in active_provider_ids
        and candidate.source != RegistryEntrySource.via_register_api
    ]

    for model in stale_models:
        logger.info(f"Removing model {model.identifier} from disabled provider {model.provider_id}")
        await self.unregister_object(model)
async def register_from_config_models(self) -> None:
    """Register the models listed in the current run config as ``from_config`` entries.

    Does nothing when ``self.current_run_config`` is unset. Entries with no
    ``provider_id``, or with the ``"__disabled__"`` provider, are skipped.
    Each remaining entry is wrapped in a ``ModelWithOwner`` and handed straight
    to ``self.dist_registry.register``.

    NOTE(review): the original author's comments state that from_config models
    are only cached in memory (not persisted to disk), so stale ones disappear
    on their own; that is not verifiable from this block -- confirm against the
    registry implementation.
    """
    run_config = self.current_run_config
    if not run_config:
        return

    for entry in run_config.models:
        provider_id = entry.provider_id
        if not provider_id or provider_id == "__disabled__":
            continue

        # The provider-side id falls back to model_id when none is given.
        resource_id = entry.provider_model_id or entry.model_id

        # A model_id that differs from the provider-side id is used verbatim;
        # otherwise the identifier is namespaced as "<provider_id>/<resource_id>".
        identifier = entry.model_id if entry.model_id != resource_id else f"{provider_id}/{resource_id}"

        await self.dist_registry.register(
            ModelWithOwner(
                identifier=identifier,
                provider_resource_id=resource_id,
                provider_id=provider_id,
                metadata=entry.metadata,
                model_type=entry.model_type or ModelType.llm,
                source=RegistryEntrySource.from_config,
            )
        )
async def update_registered_models(
self,
provider_id: str,