mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
fix: clear model cache when run.yaml model list changes
This commit is contained in:
parent
521865c388
commit
9e79e917f6
5 changed files with 99 additions and 3 deletions
|
@ -40,6 +40,7 @@ RoutingKey = str | list[str]
|
|||
class RegistryEntrySource(StrEnum):
|
||||
via_register_api = "via_register_api"
|
||||
listed_from_provider = "listed_from_provider"
|
||||
from_config = "from_config"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
|
|
|
@ -191,7 +191,14 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if not is_action_allowed(self.policy, "delete", obj, user):
|
||||
raise AccessDeniedError("delete", obj, user)
|
||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
||||
|
||||
# Only try to unregister from provider if the provider still exists
|
||||
if obj.provider_id in self.impls_by_provider_id:
|
||||
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
||||
else:
|
||||
logger.debug(
|
||||
f"Provider {obj.provider_id} no longer exists, skipping provider unregistration for {obj.identifier}"
|
||||
)
|
||||
|
||||
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
|
||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||
|
@ -254,6 +261,13 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
|
|||
if model is not None:
|
||||
return model
|
||||
|
||||
# Check from_config models if this is a ModelsRoutingTable
|
||||
if hasattr(routing_table, "_generate_from_config_models"):
|
||||
from_config_models = routing_table._generate_from_config_models()
|
||||
for from_config_model in from_config_models:
|
||||
if from_config_model.identifier == model_id:
|
||||
return from_config_model
|
||||
|
||||
logger.warning(
|
||||
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
|
||||
"searching in all providers. This is only for backwards compatibility and will stop working "
|
||||
|
|
|
@ -12,6 +12,7 @@ from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
|||
from llama_stack.core.datatypes import (
|
||||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -22,6 +23,7 @@ logger = get_logger(name=__name__, category="core::routing_tables")
|
|||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
listed_providers: set[str] = set()
|
||||
current_run_config: "StackRunConfig | None" = None
|
||||
|
||||
async def refresh(self) -> None:
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
|
@ -74,6 +76,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
source: RegistryEntrySource = RegistryEntrySource.via_register_api,
|
||||
) -> Model:
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this model
|
||||
|
@ -106,7 +109,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
model_type=model_type,
|
||||
source=RegistryEntrySource.via_register_api,
|
||||
source=source,
|
||||
)
|
||||
registered_model = await self.register_object(model)
|
||||
return registered_model
|
||||
|
@ -117,6 +120,61 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
raise ModelNotFoundError(model_id)
|
||||
await self.unregister_object(existing_model)
|
||||
|
||||
async def cleanup_disabled_provider_models(self) -> None:
|
||||
"""Remove models from providers that are no longer enabled in the current run config."""
|
||||
if not self.current_run_config:
|
||||
return
|
||||
|
||||
# Get enabled provider IDs from the current run config
|
||||
enabled_provider_ids = set()
|
||||
for _api, providers in self.current_run_config.providers.items():
|
||||
for provider in providers:
|
||||
if provider.provider_id and provider.provider_id != "__disabled__":
|
||||
enabled_provider_ids.add(provider.provider_id)
|
||||
|
||||
# Get all existing models
|
||||
existing_models = await self.get_all_with_type("model")
|
||||
|
||||
# Find models from disabled providers (excluding user-registered models)
|
||||
models_to_remove = []
|
||||
for model in existing_models:
|
||||
if model.provider_id not in enabled_provider_ids and model.source != RegistryEntrySource.via_register_api:
|
||||
models_to_remove.append(model)
|
||||
|
||||
# Remove the models
|
||||
for model in models_to_remove:
|
||||
logger.info(f"Removing model {model.identifier} from disabled provider {model.provider_id}")
|
||||
await self.unregister_object(model)
|
||||
|
||||
async def register_from_config_models(self) -> None:
|
||||
"""Register from_config models from the current run configuration."""
|
||||
if not self.current_run_config:
|
||||
return
|
||||
|
||||
# Register new from_config models (old ones automatically disappear since they're not persisted)
|
||||
for model_input in self.current_run_config.models:
|
||||
# Skip models with disabled providers
|
||||
if not model_input.provider_id or model_input.provider_id == "__disabled__":
|
||||
continue
|
||||
|
||||
# Generate identifier
|
||||
if model_input.model_id != (model_input.provider_model_id or model_input.model_id):
|
||||
identifier = model_input.model_id
|
||||
else:
|
||||
identifier = f"{model_input.provider_id}/{model_input.provider_model_id or model_input.model_id}"
|
||||
|
||||
model = ModelWithOwner(
|
||||
identifier=identifier,
|
||||
provider_resource_id=model_input.provider_model_id or model_input.model_id,
|
||||
provider_id=model_input.provider_id,
|
||||
metadata=model_input.metadata,
|
||||
model_type=model_input.model_type or ModelType.llm,
|
||||
source=RegistryEntrySource.from_config,
|
||||
)
|
||||
|
||||
# Register the model (will be cached in memory but not persisted to disk)
|
||||
await self.dist_registry.register(model)
|
||||
|
||||
async def update_registered_models(
|
||||
self,
|
||||
provider_id: str,
|
||||
|
|
|
@ -101,6 +101,15 @@ TEST_RECORDING_CONTEXT = None
|
|||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||
# Set the run config on the models routing table for generating from_config models
|
||||
if Api.models in impls:
|
||||
models_impl = impls[Api.models]
|
||||
models_impl.current_run_config = run_config
|
||||
# Clean up models from disabled providers
|
||||
await models_impl.cleanup_disabled_provider_models()
|
||||
# Register from_config models
|
||||
await models_impl.register_from_config_models()
|
||||
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config, rsrc)
|
||||
if api not in impls:
|
||||
|
@ -118,7 +127,16 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
# we want to maintain the type information in arguments to method.
|
||||
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
||||
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
|
||||
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
|
||||
kwargs = {k: getattr(obj, k) for k in obj.model_dump().keys()}
|
||||
|
||||
# Skip registering from_config models since they are registered through the routing table's set_run_config
|
||||
if rsrc == "models":
|
||||
logger.debug(
|
||||
f"Skipping registration of from_config model {obj.model_id} - will be registered through routing table"
|
||||
)
|
||||
continue
|
||||
|
||||
await method(**kwargs)
|
||||
|
||||
method = getattr(impls[api], list_method)
|
||||
response = await method()
|
||||
|
|
|
@ -102,6 +102,11 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
f"Object {existing_obj.type}:{existing_obj.identifier}'s {existing_obj.provider_id} provider is being replaced with {obj.provider_id}"
|
||||
)
|
||||
|
||||
# Skip persistence for from_config objects - they should only exist in memory
|
||||
if hasattr(obj, "source") and obj.source == "from_config":
|
||||
logger.debug(f"Skipping persistence for from_config object {obj.type}:{obj.identifier}")
|
||||
return True
|
||||
|
||||
await self.kvstore.set(
|
||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||
obj.model_dump_json(),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue