working bedrock tests

This commit is contained in:
Dinesh Yeduguru 2024-11-07 14:57:12 -08:00
parent d960f9b60f
commit e0f227f23c
4 changed files with 64 additions and 15 deletions

View file

@ -7,13 +7,14 @@
import pytest
import pytest_asyncio
from llama_stack.apis.shields import Shield, ShieldType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.adapters.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.inline.safety.meta_reference import (
LlamaGuardShieldConfig,
SafetyConfig,
)
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
@ -104,4 +105,32 @@ async def safety_stack(inference_model, safety_model, request):
providers,
provider_data,
)
return impls[Api.safety], impls[Api.shields]
safety_impl = impls[Api.safety]
shields_impl = impls[Api.shields]
# Register the appropriate shield based on provider type
provider_id = safety_fixture.providers[0].provider_id
provider_type = safety_fixture.providers[0].provider_type
shield_config = {}
identifier = "llama_guard"
if provider_type == "meta-reference":
shield_config["model"] = safety_model
elif provider_type == "remote::together":
shield_config["model"] = safety_model
elif provider_type == "remote::bedrock":
identifier = request.config.getoption("--bedrock-guardrail-id", None)
shield_config["guardrailVersion"] = request.config.getoption(
"--bedrock-guardrail-version", None
)
# Create shield
shield = Shield(
identifier=identifier,
shield_type=ShieldType.llama_guard,
provider_id=provider_id,
params=shield_config,
)
return safety_impl, shields_impl, shield