This commit is contained in:
Ashwin Bharambe 2025-07-25 10:58:36 -07:00
parent f6ba8a123d
commit 145da06fdf
4 changed files with 7 additions and 7 deletions

View file

@ -206,7 +206,7 @@ inference_store:
models: [] models: []
shields: shields:
- shield_id: llama-guard - shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+inline::llama-guard} provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=} provider_shield_id: ${env.SAFETY_MODEL:=}
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []

View file

@ -206,7 +206,7 @@ inference_store:
models: [] models: []
shields: shields:
- shield_id: llama-guard - shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+inline::llama-guard} provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=} provider_shield_id: ${env.SAFETY_MODEL:=}
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []

View file

@ -161,7 +161,7 @@ def get_distribution_template() -> DistributionTemplate:
# if the # if the
ShieldInput( ShieldInput(
shield_id="llama-guard", shield_id="llama-guard",
provider_id="${env.SAFETY_MODEL:+inline::llama-guard}", provider_id="${env.SAFETY_MODEL:+llama-guard}",
provider_shield_id="${env.SAFETY_MODEL:=}", provider_shield_id="${env.SAFETY_MODEL:=}",
), ),
] ]

View file

@ -502,7 +502,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
# Find the user model and provider model # Find the user model and provider model
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None) user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
provider_model = next((m for m in models.data if m.identifier == "different-model"), None) provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
assert user_model is not None assert user_model is not None
assert user_model.source == RegistryEntrySource.via_register_api assert user_model.source == RegistryEntrySource.via_register_api
@ -558,12 +558,12 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
identifiers = {m.identifier for m in models.data} identifiers = {m.identifier for m in models.data}
assert "test_provider/user-model" in identifiers # User model preserved assert "test_provider/user-model" in identifiers # User model preserved
assert "provider-model-new" in identifiers # New provider model (uses provider's identifier) assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier)
assert "provider-model-old" not in identifiers # Old provider model removed assert "test_provider/provider-model-old" not in identifiers # Old provider model removed
# Verify sources are correct # Verify sources are correct
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None) user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None) provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None)
assert user_model.source == RegistryEntrySource.via_register_api assert user_model.source == RegistryEntrySource.via_register_api
assert provider_model.source == RegistryEntrySource.listed_from_provider assert provider_model.source == RegistryEntrySource.listed_from_provider