This commit is contained in:
Ashwin Bharambe 2025-07-25 10:58:36 -07:00
parent f6ba8a123d
commit 145da06fdf
4 changed files with 7 additions and 7 deletions

View file

@ -206,7 +206,7 @@ inference_store:
models: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+inline::llama-guard}
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
vector_dbs: []
datasets: []

View file

@ -206,7 +206,7 @@ inference_store:
models: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+inline::llama-guard}
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
vector_dbs: []
datasets: []

View file

@ -161,7 +161,7 @@ def get_distribution_template() -> DistributionTemplate:
# if the
ShieldInput(
shield_id="llama-guard",
provider_id="${env.SAFETY_MODEL:+inline::llama-guard}",
provider_id="${env.SAFETY_MODEL:+llama-guard}",
provider_shield_id="${env.SAFETY_MODEL:=}",
),
]

View file

@ -502,7 +502,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
# Find the user model and provider model
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
provider_model = next((m for m in models.data if m.identifier == "different-model"), None)
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
assert user_model is not None
assert user_model.source == RegistryEntrySource.via_register_api
@ -558,12 +558,12 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
identifiers = {m.identifier for m in models.data}
assert "test_provider/user-model" in identifiers # User model preserved
assert "provider-model-new" in identifiers # New provider model (uses provider's identifier)
assert "provider-model-old" not in identifiers # Old provider model removed
assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier)
assert "test_provider/provider-model-old" not in identifiers # Old provider model removed
# Verify sources are correct
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None)
provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None)
assert user_model.source == RegistryEntrySource.via_register_api
assert provider_model.source == RegistryEntrySource.listed_from_provider