mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Merge 52503490d8
into d266c59c2a
This commit is contained in:
commit
c645c49924
6 changed files with 342 additions and 89 deletions
|
@ -40,6 +40,7 @@ RoutingKey = str | list[str]
|
||||||
class RegistryEntrySource(StrEnum):
|
class RegistryEntrySource(StrEnum):
|
||||||
via_register_api = "via_register_api"
|
via_register_api = "via_register_api"
|
||||||
listed_from_provider = "listed_from_provider"
|
listed_from_provider = "listed_from_provider"
|
||||||
|
from_config = "from_config"
|
||||||
|
|
||||||
|
|
||||||
class User(BaseModel):
|
class User(BaseModel):
|
||||||
|
|
|
@ -191,7 +191,14 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
if not is_action_allowed(self.policy, "delete", obj, user):
|
if not is_action_allowed(self.policy, "delete", obj, user):
|
||||||
raise AccessDeniedError("delete", obj, user)
|
raise AccessDeniedError("delete", obj, user)
|
||||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||||
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
|
||||||
|
# Only try to unregister from provider if the provider still exists
|
||||||
|
if obj.provider_id in self.impls_by_provider_id:
|
||||||
|
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"Provider {obj.provider_id} no longer exists, skipping provider unregistration for {obj.identifier}"
|
||||||
|
)
|
||||||
|
|
||||||
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
|
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
|
||||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||||
|
@ -254,6 +261,13 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
|
||||||
if model is not None:
|
if model is not None:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
# Check from_config models if this is a ModelsRoutingTable
|
||||||
|
if hasattr(routing_table, "_generate_from_config_models"):
|
||||||
|
from_config_models = routing_table._generate_from_config_models()
|
||||||
|
for from_config_model in from_config_models:
|
||||||
|
if from_config_model.identifier == model_id:
|
||||||
|
return from_config_model
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
|
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
|
||||||
"searching in all providers. This is only for backwards compatibility and will stop working "
|
"searching in all providers. This is only for backwards compatibility and will stop working "
|
||||||
|
|
|
@ -12,6 +12,7 @@ from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
||||||
from llama_stack.core.datatypes import (
|
from llama_stack.core.datatypes import (
|
||||||
ModelWithOwner,
|
ModelWithOwner,
|
||||||
RegistryEntrySource,
|
RegistryEntrySource,
|
||||||
|
StackRunConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
@ -22,6 +23,7 @@ logger = get_logger(name=__name__, category="core::routing_tables")
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
listed_providers: set[str] = set()
|
listed_providers: set[str] = set()
|
||||||
|
current_run_config: "StackRunConfig | None" = None
|
||||||
|
|
||||||
async def refresh(self) -> None:
|
async def refresh(self) -> None:
|
||||||
for provider_id, provider in self.impls_by_provider_id.items():
|
for provider_id, provider in self.impls_by_provider_id.items():
|
||||||
|
@ -74,6 +76,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
provider_id: str | None = None,
|
provider_id: str | None = None,
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
model_type: ModelType | None = None,
|
model_type: ModelType | None = None,
|
||||||
|
source: RegistryEntrySource = RegistryEntrySource.via_register_api,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
# If provider_id not specified, use the only provider if it supports this model
|
# If provider_id not specified, use the only provider if it supports this model
|
||||||
|
@ -106,7 +109,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
source=RegistryEntrySource.via_register_api,
|
source=source,
|
||||||
)
|
)
|
||||||
registered_model = await self.register_object(model)
|
registered_model = await self.register_object(model)
|
||||||
return registered_model
|
return registered_model
|
||||||
|
@ -117,6 +120,61 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
raise ModelNotFoundError(model_id)
|
raise ModelNotFoundError(model_id)
|
||||||
await self.unregister_object(existing_model)
|
await self.unregister_object(existing_model)
|
||||||
|
|
||||||
|
async def cleanup_disabled_provider_models(self) -> None:
    """Unregister models whose provider is not enabled in the current run config.

    Does nothing when ``current_run_config`` is unset. A provider counts as
    enabled when it is listed under any API of the run config with a non-empty
    ``provider_id`` other than ``"__disabled__"``. Models whose ``source`` is
    ``via_register_api`` are kept even if their provider is not enabled.
    """
    run_config = self.current_run_config
    if not run_config:
        return

    # Provider ids are collected across every API section of the config.
    active_provider_ids = {
        entry.provider_id
        for entries in run_config.providers.values()
        for entry in entries
        if entry.provider_id and entry.provider_id != "__disabled__"
    }

    registered = await self.get_all_with_type("model")

    # Decide on the full removal list first, then unregister one by one.
    stale = [
        candidate
        for candidate in registered
        if candidate.provider_id not in active_provider_ids
        and candidate.source != RegistryEntrySource.via_register_api
    ]

    for candidate in stale:
        logger.info(f"Removing model {candidate.identifier} from disabled provider {candidate.provider_id}")
        await self.unregister_object(candidate)
|
||||||
|
|
||||||
|
async def register_from_config_models(self) -> None:
    """Register every model declared in the current run config as ``from_config``.

    Does nothing when ``current_run_config`` is unset. Entries with no
    ``provider_id`` or with the ``"__disabled__"`` provider are skipped. Each
    remaining entry is turned into a ``ModelWithOwner`` and handed straight to
    ``dist_registry.register``.
    """
    run_config = self.current_run_config
    if not run_config:
        return

    for entry in run_config.models:
        provider_id = entry.provider_id
        if not provider_id or provider_id == "__disabled__":
            continue

        # The provider-side id falls back to model_id when no provider_model_id is given.
        resource_id = entry.provider_model_id or entry.model_id

        # An explicit alias (model_id different from the provider-side id) is kept as-is;
        # otherwise the identifier is namespaced by the provider.
        identifier = entry.model_id if entry.model_id != resource_id else f"{provider_id}/{resource_id}"

        # NOTE(review): the registry is expected to hold from_config entries in memory
        # only and not persist them — confirm against the registry implementation.
        await self.dist_registry.register(
            ModelWithOwner(
                identifier=identifier,
                provider_resource_id=resource_id,
                provider_id=provider_id,
                metadata=entry.metadata,
                model_type=entry.model_type or ModelType.llm,
                source=RegistryEntrySource.from_config,
            )
        )
|
||||||
|
|
||||||
async def update_registered_models(
|
async def update_registered_models(
|
||||||
self,
|
self,
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
|
|
|
@ -99,6 +99,15 @@ TEST_RECORDING_CONTEXT = None
|
||||||
|
|
||||||
|
|
||||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||||
|
# Set the run config on the models routing table for generating from_config models
|
||||||
|
if Api.models in impls:
|
||||||
|
models_impl = impls[Api.models]
|
||||||
|
models_impl.current_run_config = run_config
|
||||||
|
# Clean up models from disabled providers
|
||||||
|
await models_impl.cleanup_disabled_provider_models()
|
||||||
|
# Register from_config models
|
||||||
|
await models_impl.register_from_config_models()
|
||||||
|
|
||||||
for rsrc, api, register_method, list_method in RESOURCES:
|
for rsrc, api, register_method, list_method in RESOURCES:
|
||||||
objects = getattr(run_config, rsrc)
|
objects = getattr(run_config, rsrc)
|
||||||
if api not in impls:
|
if api not in impls:
|
||||||
|
@ -116,7 +125,16 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||||
# we want to maintain the type information in arguments to method.
|
# we want to maintain the type information in arguments to method.
|
||||||
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
||||||
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
|
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
|
||||||
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
|
kwargs = {k: getattr(obj, k) for k in obj.model_dump().keys()}
|
||||||
|
|
||||||
|
# Skip registering from_config models since they are registered through the routing table's set_run_config
|
||||||
|
if rsrc == "models":
|
||||||
|
logger.debug(
|
||||||
|
f"Skipping registration of from_config model {obj.model_id} - will be registered through routing table"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
await method(**kwargs)
|
||||||
|
|
||||||
method = getattr(impls[api], list_method)
|
method = getattr(impls[api], list_method)
|
||||||
response = await method()
|
response = await method()
|
||||||
|
|
|
@ -100,6 +100,11 @@ class DiskDistributionRegistry(DistributionRegistry):
|
||||||
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Skip persistence for from_config objects - they should only exist in memory
|
||||||
|
if hasattr(obj, "source") and obj.source == "from_config":
|
||||||
|
logger.debug(f"Skipping persistence for from_config object {obj.type}:{obj.identifier}")
|
||||||
|
return True
|
||||||
|
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||||
obj.model_dump_json(),
|
obj.model_dump_json(),
|
||||||
|
|
|
@ -18,7 +18,7 @@ from llama_stack.apis.models import Model, ModelType
|
||||||
from llama_stack.apis.shields.shields import Shield
|
from llama_stack.apis.shields.shields import Shield
|
||||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup
|
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
from llama_stack.core.datatypes import RegistryEntrySource
|
from llama_stack.core.datatypes import RegistryEntrySource, StackRunConfig
|
||||||
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
|
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||||
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable
|
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable
|
||||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||||
|
@ -538,114 +538,271 @@ async def test_models_source_tracking_provider(cached_disk_dist_registry):
|
||||||
await table.shutdown()
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
|
async def test_models_dynamic_from_config_generation(cached_disk_dist_registry):
|
||||||
"""Test that provider refresh preserves user-registered models with default source."""
|
"""Test that from_config models are generated dynamically from run_config."""
|
||||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# First register a user model with same provider_resource_id as provider will later provide
|
# Test that no from_config models are registered when no run_config
|
||||||
await table.register_model(
|
all_models = await table.get_all_with_type("model")
|
||||||
model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
|
from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config]
|
||||||
|
assert len(from_config_models) == 0
|
||||||
|
|
||||||
|
# Create a run config with from_config models
|
||||||
|
run_config = StackRunConfig(
|
||||||
|
image_name="test",
|
||||||
|
providers={},
|
||||||
|
models=[
|
||||||
|
{
|
||||||
|
"model_id": "from_config_model_1",
|
||||||
|
"provider_id": "test_provider",
|
||||||
|
"model_type": "llm",
|
||||||
|
"provider_model_id": "gpt-3.5-turbo",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_id": "from_config_model_2",
|
||||||
|
"provider_id": "test_provider",
|
||||||
|
"model_type": "llm",
|
||||||
|
"provider_model_id": "gpt-4",
|
||||||
|
},
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify user model is registered with default source
|
# Set the run config
|
||||||
models = await table.list_models()
|
table.current_run_config = run_config
|
||||||
assert len(models.data) == 1
|
await table.cleanup_disabled_provider_models()
|
||||||
user_model = models.data[0]
|
await table.register_from_config_models()
|
||||||
assert user_model.source == RegistryEntrySource.via_register_api
|
|
||||||
assert user_model.identifier == "my-custom-alias"
|
|
||||||
assert user_model.provider_resource_id == "provider-model-1"
|
|
||||||
|
|
||||||
# Now simulate provider refresh
|
# Test that from_config models are registered in the registry
|
||||||
provider_models = [
|
all_models = await table.get_all_with_type("model")
|
||||||
Model(
|
from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config]
|
||||||
identifier="provider-model-1",
|
assert len(from_config_models) == 2
|
||||||
provider_resource_id="provider-model-1",
|
|
||||||
provider_id="test_provider",
|
|
||||||
metadata={},
|
|
||||||
model_type=ModelType.llm,
|
|
||||||
),
|
|
||||||
Model(
|
|
||||||
identifier="different-model",
|
|
||||||
provider_resource_id="different-model",
|
|
||||||
provider_id="test_provider",
|
|
||||||
metadata={},
|
|
||||||
model_type=ModelType.llm,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
await table.update_registered_models("test_provider", provider_models)
|
|
||||||
|
|
||||||
# Verify user model with alias is preserved, but provider added new model
|
model_identifiers = {m.identifier for m in from_config_models}
|
||||||
models = await table.list_models()
|
assert "from_config_model_1" in model_identifiers
|
||||||
assert len(models.data) == 2
|
assert "from_config_model_2" in model_identifiers
|
||||||
|
|
||||||
# Find the user model and provider model
|
# Test that from_config models have correct attributes
|
||||||
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
model_1 = next(m for m in from_config_models if m.identifier == "from_config_model_1")
|
||||||
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
|
assert model_1.provider_id == "test_provider"
|
||||||
|
assert model_1.provider_resource_id == "gpt-3.5-turbo"
|
||||||
assert user_model is not None
|
assert model_1.model_type == ModelType.llm
|
||||||
assert user_model.source == RegistryEntrySource.via_register_api
|
assert model_1.source == RegistryEntrySource.from_config
|
||||||
assert user_model.provider_resource_id == "provider-model-1"
|
|
||||||
|
|
||||||
assert provider_model is not None
|
|
||||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
|
||||||
assert provider_model.provider_resource_id == "different-model"
|
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
await table.shutdown()
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
|
async def test_models_dynamic_from_config_lookup(cached_disk_dist_registry):
|
||||||
"""Test that provider refresh removes old provider models but keeps default ones."""
|
"""Test that from_config models can be looked up individually."""
|
||||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
# Register a user model
|
# Create a run config with from_config models
|
||||||
await table.register_model(model_id="user-model", provider_id="test_provider")
|
run_config = StackRunConfig(
|
||||||
|
image_name="test",
|
||||||
|
providers={},
|
||||||
|
models=[
|
||||||
|
{
|
||||||
|
"model_id": "lookup_test_model",
|
||||||
|
"provider_id": "test_provider",
|
||||||
|
"model_type": "llm",
|
||||||
|
"provider_model_id": "gpt-3.5-turbo",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Add some provider models
|
# Set the run config
|
||||||
provider_models_v1 = [
|
table.current_run_config = run_config
|
||||||
Model(
|
await table.cleanup_disabled_provider_models()
|
||||||
identifier="provider-model-old",
|
await table.register_from_config_models()
|
||||||
provider_resource_id="provider-model-old",
|
|
||||||
provider_id="test_provider",
|
|
||||||
metadata={},
|
|
||||||
model_type=ModelType.llm,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
await table.update_registered_models("test_provider", provider_models_v1)
|
|
||||||
|
|
||||||
# Verify we have both user and provider models
|
# Test that we can get the from_config model individually
|
||||||
|
model = await table.get_model("lookup_test_model")
|
||||||
|
assert model is not None
|
||||||
|
assert model.identifier == "lookup_test_model"
|
||||||
|
assert model.provider_id == "test_provider"
|
||||||
|
assert model.provider_resource_id == "gpt-3.5-turbo"
|
||||||
|
assert model.source == RegistryEntrySource.from_config
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_dynamic_from_config_mixed_with_persistent(cached_disk_dist_registry):
|
||||||
|
"""Test that from_config models work alongside persistent models."""
|
||||||
|
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
# Create a run config with from_config models
|
||||||
|
run_config = StackRunConfig(
|
||||||
|
image_name="test",
|
||||||
|
providers={},
|
||||||
|
models=[
|
||||||
|
{
|
||||||
|
"model_id": "from_config_model",
|
||||||
|
"provider_id": "test_provider",
|
||||||
|
"model_type": "llm",
|
||||||
|
"provider_model_id": "gpt-3.5-turbo",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the run config
|
||||||
|
table.current_run_config = run_config
|
||||||
|
await table.cleanup_disabled_provider_models()
|
||||||
|
await table.register_from_config_models()
|
||||||
|
|
||||||
|
# Test that from_config models are included
|
||||||
models = await table.list_models()
|
models = await table.list_models()
|
||||||
assert len(models.data) == 2
|
from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config]
|
||||||
|
|
||||||
# Now update with new provider models (should remove old provider models)
|
assert len(from_config_models) == 1
|
||||||
provider_models_v2 = [
|
assert from_config_models[0].identifier == "from_config_model"
|
||||||
Model(
|
|
||||||
identifier="provider-model-new",
|
|
||||||
provider_resource_id="provider-model-new",
|
|
||||||
provider_id="test_provider",
|
|
||||||
metadata={},
|
|
||||||
model_type=ModelType.llm,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
await table.update_registered_models("test_provider", provider_models_v2)
|
|
||||||
|
|
||||||
# Should have user model + new provider model, old provider model gone
|
# Test that we can get the from_config model individually
|
||||||
|
from_config_model = await table.get_model("from_config_model")
|
||||||
|
assert from_config_model is not None
|
||||||
|
assert from_config_model.source == RegistryEntrySource.from_config
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_dynamic_from_config_disabled_providers(cached_disk_dist_registry):
    """Run-config models bound to the ``__disabled__`` provider must not be registered."""
    routing_table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
    await routing_table.initialize()

    # One entry on a regular provider, one on the "__disabled__" sentinel provider.
    enabled_entry = {
        "model_id": "enabled_model",
        "provider_id": "test_provider",
        "model_type": "llm",
        "provider_model_id": "gpt-3.5-turbo",
    }
    disabled_entry = {
        "model_id": "disabled_model",
        "provider_id": "__disabled__",
        "model_type": "llm",
        "provider_model_id": "gpt-4",
    }
    config = StackRunConfig(
        image_name="test",
        providers={},
        models=[enabled_entry, disabled_entry],
    )

    # Apply the run config in the same order the stack does: cleanup first, then registration.
    routing_table.current_run_config = config
    await routing_table.cleanup_disabled_provider_models()
    await routing_table.register_from_config_models()

    # Only the entry on the enabled provider may show up as a from_config model.
    registered = await routing_table.get_all_with_type("model")
    from_config_identifiers = [
        model.identifier for model in registered if model.source == RegistryEntrySource.from_config
    ]
    assert from_config_identifiers == ["enabled_model"]

    # Cleanup
    await routing_table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_dynamic_from_config_no_run_config(cached_disk_dist_registry):
|
||||||
|
"""Test that from_config models work when no run_config is set."""
|
||||||
|
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
# Test that list_models works without run_config
|
||||||
models = await table.list_models()
|
models = await table.list_models()
|
||||||
assert len(models.data) == 2
|
from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config]
|
||||||
|
assert len(from_config_models) == 0 # No from_config models when no run_config
|
||||||
|
|
||||||
identifiers = {m.identifier for m in models.data}
|
# Cleanup
|
||||||
assert "test_provider/user-model" in identifiers # User model preserved
|
await table.shutdown()
|
||||||
assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier)
|
|
||||||
assert "test_provider/provider-model-old" not in identifiers # Old provider model removed
|
|
||||||
|
|
||||||
# Verify sources are correct
|
|
||||||
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
|
|
||||||
provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None)
|
|
||||||
|
|
||||||
assert user_model.source == RegistryEntrySource.via_register_api
|
async def test_models_filter_persistent_models_from_removed_providers(cached_disk_dist_registry):
|
||||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
"""Test that models from removed providers are filtered out from persistent models."""
|
||||||
|
from llama_stack.apis.models import ModelType
|
||||||
|
from llama_stack.core.datatypes import ModelWithOwner, Provider, RegistryEntrySource, StackRunConfig
|
||||||
|
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||||
|
|
||||||
|
# Create a routing table
|
||||||
|
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
# Create some mock persistent models
|
||||||
|
model1 = ModelWithOwner(
|
||||||
|
identifier="test_provider_1/model1",
|
||||||
|
provider_resource_id="model1",
|
||||||
|
provider_id="test_provider_1",
|
||||||
|
metadata={},
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
source=RegistryEntrySource.listed_from_provider,
|
||||||
|
)
|
||||||
|
model2 = ModelWithOwner(
|
||||||
|
identifier="test_provider_2/model2",
|
||||||
|
provider_resource_id="model2",
|
||||||
|
provider_id="test_provider_2",
|
||||||
|
metadata={},
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
source=RegistryEntrySource.listed_from_provider,
|
||||||
|
)
|
||||||
|
user_model = ModelWithOwner(
|
||||||
|
identifier="user_model",
|
||||||
|
provider_resource_id="user_model",
|
||||||
|
provider_id="test_provider_1",
|
||||||
|
metadata={},
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
source=RegistryEntrySource.via_register_api,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a run config that only includes test_provider_1 (test_provider_2 is removed)
|
||||||
|
run_config = StackRunConfig(
|
||||||
|
image_name="test",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(provider_id="test_provider_1", provider_type="openai", config={"api_key": "test_key"}),
|
||||||
|
# test_provider_2 is removed from run.yaml
|
||||||
|
]
|
||||||
|
},
|
||||||
|
models=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the run config
|
||||||
|
table.current_run_config = run_config
|
||||||
|
await table.cleanup_disabled_provider_models()
|
||||||
|
await table.register_from_config_models()
|
||||||
|
|
||||||
|
# Test the cleanup logic directly
|
||||||
|
# First, manually add models to the registry to simulate existing models
|
||||||
|
await table.dist_registry.register(model1)
|
||||||
|
await table.dist_registry.register(model2)
|
||||||
|
await table.dist_registry.register(user_model)
|
||||||
|
|
||||||
|
# Now set the run config which should trigger cleanup
|
||||||
|
table.current_run_config = run_config
|
||||||
|
await table.cleanup_disabled_provider_models()
|
||||||
|
await table.register_from_config_models()
|
||||||
|
|
||||||
|
# Get the list of models after cleanup
|
||||||
|
response = await table.list_models()
|
||||||
|
model_identifiers = {m.identifier for m in response.data}
|
||||||
|
|
||||||
|
# Should have user_model (user-registered) and model1 (from enabled provider), but not model2 (from disabled provider)
|
||||||
|
# model1 should be kept because test_provider_1 is in the run config (enabled)
|
||||||
|
# model2 should be removed because test_provider_2 is not in the run config (disabled)
|
||||||
|
# user_model should be kept because it's user-registered
|
||||||
|
assert "user_model" in model_identifiers
|
||||||
|
assert "test_provider_1/model1" in model_identifiers
|
||||||
|
assert "test_provider_2/model2" not in model_identifiers
|
||||||
|
|
||||||
|
# Test that user-registered models are always kept regardless of provider status
|
||||||
|
user_model_found = next((m for m in response.data if m.identifier == "user_model"), None)
|
||||||
|
assert user_model_found is not None
|
||||||
|
assert user_model_found.source == RegistryEntrySource.via_register_api
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
await table.shutdown()
|
await table.shutdown()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue