mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-18 15:27:16 +00:00
Kill the notion of shield_type
This commit is contained in:
parent
09269e2a44
commit
b1c3a95485
20 changed files with 87 additions and 161 deletions
|
@ -9,7 +9,7 @@ import pytest_asyncio
|
|||
|
||||
from llama_stack.apis.models import ModelInput
|
||||
|
||||
from llama_stack.apis.shields import ShieldInput, ShieldType
|
||||
from llama_stack.apis.shields import ShieldInput
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
||||
|
@ -41,7 +41,7 @@ def safety_llama_guard(safety_model) -> ProviderFixture:
|
|||
Provider(
|
||||
provider_id="inline::llama-guard",
|
||||
provider_type="inline::llama-guard",
|
||||
config=LlamaGuardConfig(model=safety_model).model_dump(),
|
||||
config=LlamaGuardConfig().model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
@ -114,20 +114,14 @@ async def safety_stack(inference_model, safety_model, request):
|
|||
|
||||
|
||||
def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
|
||||
shield_config = {}
|
||||
shield_type = ShieldType.llama_guard
|
||||
identifier = "llama_guard"
|
||||
if provider_type == "meta-reference":
|
||||
shield_config["model"] = safety_model
|
||||
elif provider_type == "remote::together":
|
||||
shield_config["model"] = safety_model
|
||||
elif provider_type == "remote::bedrock":
|
||||
if provider_type == "remote::bedrock":
|
||||
identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
|
||||
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
|
||||
shield_type = ShieldType.generic_content_shield
|
||||
params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
|
||||
else:
|
||||
params = {}
|
||||
identifier = safety_model
|
||||
|
||||
return ShieldInput(
|
||||
shield_id=identifier,
|
||||
shield_type=shield_type,
|
||||
params=shield_config,
|
||||
params=params,
|
||||
)
|
||||
|
|
|
@ -34,7 +34,6 @@ class TestSafety:
|
|||
|
||||
for shield in response:
|
||||
assert isinstance(shield, Shield)
|
||||
assert shield.shield_type in [v for v in ShieldType]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_shield(self, safety_stack):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue