mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
test: add tests for model not persistant models
This commit is contained in:
parent
9e79e917f6
commit
52503490d8
1 changed files with 250 additions and 93 deletions
|
@ -17,7 +17,7 @@ from llama_stack.apis.models import Model, ModelType
|
|||
from llama_stack.apis.shields.shields import Shield
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.core.datatypes import RegistryEntrySource
|
||||
from llama_stack.core.datatypes import RegistryEntrySource, StackRunConfig
|
||||
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
|
@ -534,114 +534,271 @@ async def test_models_source_tracking_provider(cached_disk_dist_registry):
|
|||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
|
||||
"""Test that provider refresh preserves user-registered models with default source."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
async def test_models_dynamic_from_config_generation(cached_disk_dist_registry):
|
||||
"""Test that from_config models are generated dynamically from run_config."""
|
||||
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# First register a user model with same provider_resource_id as provider will later provide
|
||||
await table.register_model(
|
||||
model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
|
||||
# Test that no from_config models are registered when no run_config
|
||||
all_models = await table.get_all_with_type("model")
|
||||
from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config]
|
||||
assert len(from_config_models) == 0
|
||||
|
||||
# Create a run config with from_config models
|
||||
run_config = StackRunConfig(
|
||||
image_name="test",
|
||||
providers={},
|
||||
models=[
|
||||
{
|
||||
"model_id": "from_config_model_1",
|
||||
"provider_id": "test_provider",
|
||||
"model_type": "llm",
|
||||
"provider_model_id": "gpt-3.5-turbo",
|
||||
},
|
||||
{
|
||||
"model_id": "from_config_model_2",
|
||||
"provider_id": "test_provider",
|
||||
"model_type": "llm",
|
||||
"provider_model_id": "gpt-4",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Verify user model is registered with default source
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
user_model = models.data[0]
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert user_model.identifier == "my-custom-alias"
|
||||
assert user_model.provider_resource_id == "provider-model-1"
|
||||
# Set the run config
|
||||
table.current_run_config = run_config
|
||||
await table.cleanup_disabled_provider_models()
|
||||
await table.register_from_config_models()
|
||||
|
||||
# Now simulate provider refresh
|
||||
provider_models = [
|
||||
Model(
|
||||
identifier="provider-model-1",
|
||||
provider_resource_id="provider-model-1",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
Model(
|
||||
identifier="different-model",
|
||||
provider_resource_id="different-model",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models)
|
||||
# Test that from_config models are registered in the registry
|
||||
all_models = await table.get_all_with_type("model")
|
||||
from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config]
|
||||
assert len(from_config_models) == 2
|
||||
|
||||
# Verify user model with alias is preserved, but provider added new model
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
model_identifiers = {m.identifier for m in from_config_models}
|
||||
assert "from_config_model_1" in model_identifiers
|
||||
assert "from_config_model_2" in model_identifiers
|
||||
|
||||
# Find the user model and provider model
|
||||
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
|
||||
|
||||
assert user_model is not None
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert user_model.provider_resource_id == "provider-model-1"
|
||||
|
||||
assert provider_model is not None
|
||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
||||
assert provider_model.provider_resource_id == "different-model"
|
||||
# Test that from_config models have correct attributes
|
||||
model_1 = next(m for m in from_config_models if m.identifier == "from_config_model_1")
|
||||
assert model_1.provider_id == "test_provider"
|
||||
assert model_1.provider_resource_id == "gpt-3.5-turbo"
|
||||
assert model_1.model_type == ModelType.llm
|
||||
assert model_1.source == RegistryEntrySource.from_config
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
|
||||
"""Test that provider refresh removes old provider models but keeps default ones."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
async def test_models_dynamic_from_config_lookup(cached_disk_dist_registry):
|
||||
"""Test that from_config models can be looked up individually."""
|
||||
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a user model
|
||||
await table.register_model(model_id="user-model", provider_id="test_provider")
|
||||
# Create a run config with from_config models
|
||||
run_config = StackRunConfig(
|
||||
image_name="test",
|
||||
providers={},
|
||||
models=[
|
||||
{
|
||||
"model_id": "lookup_test_model",
|
||||
"provider_id": "test_provider",
|
||||
"model_type": "llm",
|
||||
"provider_model_id": "gpt-3.5-turbo",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# Add some provider models
|
||||
provider_models_v1 = [
|
||||
Model(
|
||||
identifier="provider-model-old",
|
||||
provider_resource_id="provider-model-old",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models_v1)
|
||||
# Set the run config
|
||||
table.current_run_config = run_config
|
||||
await table.cleanup_disabled_provider_models()
|
||||
await table.register_from_config_models()
|
||||
|
||||
# Verify we have both user and provider models
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
# Now update with new provider models (should remove old provider models)
|
||||
provider_models_v2 = [
|
||||
Model(
|
||||
identifier="provider-model-new",
|
||||
provider_resource_id="provider-model-new",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models_v2)
|
||||
|
||||
# Should have user model + new provider model, old provider model gone
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
identifiers = {m.identifier for m in models.data}
|
||||
assert "test_provider/user-model" in identifiers # User model preserved
|
||||
assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier)
|
||||
assert "test_provider/provider-model-old" not in identifiers # Old provider model removed
|
||||
|
||||
# Verify sources are correct
|
||||
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None)
|
||||
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
||||
# Test that we can get the from_config model individually
|
||||
model = await table.get_model("lookup_test_model")
|
||||
assert model is not None
|
||||
assert model.identifier == "lookup_test_model"
|
||||
assert model.provider_id == "test_provider"
|
||||
assert model.provider_resource_id == "gpt-3.5-turbo"
|
||||
assert model.source == RegistryEntrySource.from_config
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_dynamic_from_config_mixed_with_persistent(cached_disk_dist_registry):
|
||||
"""Test that from_config models work alongside persistent models."""
|
||||
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Create a run config with from_config models
|
||||
run_config = StackRunConfig(
|
||||
image_name="test",
|
||||
providers={},
|
||||
models=[
|
||||
{
|
||||
"model_id": "from_config_model",
|
||||
"provider_id": "test_provider",
|
||||
"model_type": "llm",
|
||||
"provider_model_id": "gpt-3.5-turbo",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# Set the run config
|
||||
table.current_run_config = run_config
|
||||
await table.cleanup_disabled_provider_models()
|
||||
await table.register_from_config_models()
|
||||
|
||||
# Test that from_config models are included
|
||||
models = await table.list_models()
|
||||
from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config]
|
||||
|
||||
assert len(from_config_models) == 1
|
||||
assert from_config_models[0].identifier == "from_config_model"
|
||||
|
||||
# Test that we can get the from_config model individually
|
||||
from_config_model = await table.get_model("from_config_model")
|
||||
assert from_config_model is not None
|
||||
assert from_config_model.source == RegistryEntrySource.from_config
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_dynamic_from_config_disabled_providers(cached_disk_dist_registry):
|
||||
"""Test that from_config models with disabled providers are skipped."""
|
||||
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Create a run config with disabled provider models
|
||||
run_config = StackRunConfig(
|
||||
image_name="test",
|
||||
providers={},
|
||||
models=[
|
||||
{
|
||||
"model_id": "enabled_model",
|
||||
"provider_id": "test_provider",
|
||||
"model_type": "llm",
|
||||
"provider_model_id": "gpt-3.5-turbo",
|
||||
},
|
||||
{
|
||||
"model_id": "disabled_model",
|
||||
"provider_id": "__disabled__",
|
||||
"model_type": "llm",
|
||||
"provider_model_id": "gpt-4",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Set the run config
|
||||
table.current_run_config = run_config
|
||||
await table.cleanup_disabled_provider_models()
|
||||
await table.register_from_config_models()
|
||||
|
||||
# Test that only enabled models are included
|
||||
all_models = await table.get_all_with_type("model")
|
||||
from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config]
|
||||
assert len(from_config_models) == 1
|
||||
assert from_config_models[0].identifier == "enabled_model"
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_dynamic_from_config_no_run_config(cached_disk_dist_registry):
|
||||
"""Test that from_config models work when no run_config is set."""
|
||||
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Test that list_models works without run_config
|
||||
models = await table.list_models()
|
||||
from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config]
|
||||
assert len(from_config_models) == 0 # No from_config models when no run_config
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_filter_persistent_models_from_removed_providers(cached_disk_dist_registry):
|
||||
"""Test that models from removed providers are filtered out from persistent models."""
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import ModelWithOwner, Provider, RegistryEntrySource, StackRunConfig
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
|
||||
# Create a routing table
|
||||
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Create some mock persistent models
|
||||
model1 = ModelWithOwner(
|
||||
identifier="test_provider_1/model1",
|
||||
provider_resource_id="model1",
|
||||
provider_id="test_provider_1",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
source=RegistryEntrySource.listed_from_provider,
|
||||
)
|
||||
model2 = ModelWithOwner(
|
||||
identifier="test_provider_2/model2",
|
||||
provider_resource_id="model2",
|
||||
provider_id="test_provider_2",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
source=RegistryEntrySource.listed_from_provider,
|
||||
)
|
||||
user_model = ModelWithOwner(
|
||||
identifier="user_model",
|
||||
provider_resource_id="user_model",
|
||||
provider_id="test_provider_1",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
source=RegistryEntrySource.via_register_api,
|
||||
)
|
||||
|
||||
# Create a run config that only includes test_provider_1 (test_provider_2 is removed)
|
||||
run_config = StackRunConfig(
|
||||
image_name="test",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(provider_id="test_provider_1", provider_type="openai", config={"api_key": "test_key"}),
|
||||
# test_provider_2 is removed from run.yaml
|
||||
]
|
||||
},
|
||||
models=[],
|
||||
)
|
||||
|
||||
# Set the run config
|
||||
table.current_run_config = run_config
|
||||
await table.cleanup_disabled_provider_models()
|
||||
await table.register_from_config_models()
|
||||
|
||||
# Test the cleanup logic directly
|
||||
# First, manually add models to the registry to simulate existing models
|
||||
await table.dist_registry.register(model1)
|
||||
await table.dist_registry.register(model2)
|
||||
await table.dist_registry.register(user_model)
|
||||
|
||||
# Now set the run config which should trigger cleanup
|
||||
table.current_run_config = run_config
|
||||
await table.cleanup_disabled_provider_models()
|
||||
await table.register_from_config_models()
|
||||
|
||||
# Get the list of models after cleanup
|
||||
response = await table.list_models()
|
||||
model_identifiers = {m.identifier for m in response.data}
|
||||
|
||||
# Should have user_model (user-registered) and model1 (from enabled provider), but not model2 (from disabled provider)
|
||||
# model1 should be kept because test_provider_1 is in the run config (enabled)
|
||||
# model2 should be removed because test_provider_2 is not in the run config (disabled)
|
||||
# user_model should be kept because it's user-registered
|
||||
assert "user_model" in model_identifiers
|
||||
assert "test_provider_1/model1" in model_identifiers
|
||||
assert "test_provider_2/model2" not in model_identifiers
|
||||
|
||||
# Test that user-registered models are always kept regardless of provider status
|
||||
user_model_found = next((m for m in response.data if m.identifier == "user_model"), None)
|
||||
assert user_model_found is not None
|
||||
assert user_model_found.source == RegistryEntrySource.via_register_api
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue