forked from phoenix-oss/llama-stack-mirror
Resource oriented design for shields (#399)
* init * working bedrock tests * bedrock test for inference fixes * use env vars for bedrock guardrail vars * add register in meta reference * use correct shield impl in meta ref * dont add together fixture * right naming * minor updates * improved registration flow * address feedback --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
7ee9f8d8ac
commit
d800a16acd
20 changed files with 262 additions and 124 deletions
|
@ -7,12 +7,15 @@
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.shields import ShieldType
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.safety.meta_reference import (
|
||||
LlamaGuardShieldConfig,
|
||||
SafetyConfig,
|
||||
)
|
||||
|
||||
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
||||
from llama_stack.providers.tests.env import get_env_or_fail
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
|
@ -47,7 +50,20 @@ def safety_meta_reference(safety_model) -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
SAFETY_FIXTURES = ["meta_reference", "remote"]
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_bedrock() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="bedrock",
|
||||
provider_type="remote::bedrock",
|
||||
config=BedrockSafetyConfig().model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
SAFETY_FIXTURES = ["meta_reference", "bedrock", "remote"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
@ -74,4 +90,29 @@ async def safety_stack(inference_model, safety_model, request):
|
|||
providers,
|
||||
provider_data,
|
||||
)
|
||||
return impls[Api.safety], impls[Api.shields]
|
||||
|
||||
safety_impl = impls[Api.safety]
|
||||
shields_impl = impls[Api.shields]
|
||||
|
||||
# Register the appropriate shield based on provider type
|
||||
provider_type = safety_fixture.providers[0].provider_type
|
||||
|
||||
shield_config = {}
|
||||
shield_type = ShieldType.llama_guard
|
||||
identifier = "llama_guard"
|
||||
if provider_type == "meta-reference":
|
||||
shield_config["model"] = safety_model
|
||||
elif provider_type == "remote::together":
|
||||
shield_config["model"] = safety_model
|
||||
elif provider_type == "remote::bedrock":
|
||||
identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
|
||||
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
|
||||
shield_type = ShieldType.generic_content_shield
|
||||
|
||||
shield = await shields_impl.register_shield(
|
||||
shield_id=identifier,
|
||||
shield_type=shield_type,
|
||||
params=shield_config,
|
||||
)
|
||||
|
||||
return safety_impl, shields_impl, shield
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue