diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index 4a9baaf90..2a1270107 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -206,7 +206,7 @@ inference_store: models: [] shields: - shield_id: llama-guard - provider_id: ${env.SAFETY_MODEL:+inline::llama-guard} + provider_id: ${env.SAFETY_MODEL:+llama-guard} provider_shield_id: ${env.SAFETY_MODEL:=} vector_dbs: [] datasets: [] diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index bc38387c9..40e43cde9 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -206,7 +206,7 @@ inference_store: models: [] shields: - shield_id: llama-guard - provider_id: ${env.SAFETY_MODEL:+inline::llama-guard} + provider_id: ${env.SAFETY_MODEL:+llama-guard} provider_shield_id: ${env.SAFETY_MODEL:=} vector_dbs: [] datasets: [] diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py index 4931c6a42..d0782797f 100644 --- a/llama_stack/templates/starter/starter.py +++ b/llama_stack/templates/starter/starter.py @@ -161,7 +161,7 @@ def get_distribution_template() -> DistributionTemplate: # if the ShieldInput( shield_id="llama-guard", - provider_id="${env.SAFETY_MODEL:+inline::llama-guard}", + provider_id="${env.SAFETY_MODEL:+llama-guard}", provider_shield_id="${env.SAFETY_MODEL:=}", ), ] diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index c1b57cb4f..308b5c28f 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -502,7 +502,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi # Find the user model and provider model user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None) - provider_model = next((m for m in models.data if m.identifier == "different-model"), None) + provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None) assert user_model is not None assert user_model.source == RegistryEntrySource.via_register_api @@ -558,12 +558,12 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis identifiers = {m.identifier for m in models.data} assert "test_provider/user-model" in identifiers # User model preserved - assert "provider-model-new" in identifiers # New provider model (uses provider's identifier) - assert "provider-model-old" not in identifiers # Old provider model removed + assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier) + assert "test_provider/provider-model-old" not in identifiers # Old provider model removed # Verify sources are correct user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None) - provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None) + provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None) assert user_model.source == RegistryEntrySource.via_register_api assert provider_model.source == RegistryEntrySource.listed_from_provider