Kill the notion of shield_type

This commit is contained in:
Ashwin Bharambe 2024-11-12 11:41:23 -08:00
parent 09269e2a44
commit b1c3a95485
20 changed files with 87 additions and 161 deletions

View file

@ -81,15 +81,17 @@ async def create_agent_session(agents_impl, agent_config):
class TestAgents:
@pytest.mark.asyncio
async def test_agent_turns_with_safety(self, agents_stack, common_params):
async def test_agent_turns_with_safety(
self, safety_model, agents_stack, common_params
):
agents_impl, _ = agents_stack
agent_id, session_id = await create_agent_session(
agents_impl,
AgentConfig(
**{
**common_params,
"input_shields": ["llama_guard"],
"output_shields": ["llama_guard"],
"input_shields": [safety_model],
"output_shields": [safety_model],
}
),
)

View file

@ -9,7 +9,7 @@ import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.apis.shields import ShieldInput, ShieldType
from llama_stack.apis.shields import ShieldInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@ -41,7 +41,7 @@ def safety_llama_guard(safety_model) -> ProviderFixture:
Provider(
provider_id="inline::llama-guard",
provider_type="inline::llama-guard",
config=LlamaGuardConfig(model=safety_model).model_dump(),
config=LlamaGuardConfig().model_dump(),
)
],
)
@ -114,20 +114,14 @@ async def safety_stack(inference_model, safety_model, request):
def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
shield_config = {}
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
if provider_type == "meta-reference":
shield_config["model"] = safety_model
elif provider_type == "remote::together":
shield_config["model"] = safety_model
elif provider_type == "remote::bedrock":
if provider_type == "remote::bedrock":
identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
shield_type = ShieldType.generic_content_shield
params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
else:
params = {}
identifier = safety_model
return ShieldInput(
shield_id=identifier,
shield_type=shield_type,
params=shield_config,
params=params,
)

View file

@ -34,7 +34,6 @@ class TestSafety:
for shield in response:
assert isinstance(shield, Shield)
assert shield.shield_type in [v for v in ShieldType]
@pytest.mark.asyncio
async def test_run_shield(self, safety_stack):