more fixes (some fixes to pre-existing issues in safety fixture)

This commit is contained in:
Ashwin Bharambe 2024-11-11 09:09:47 -08:00
parent 7507cd487f
commit 15ffceb533
6 changed files with 25 additions and 9 deletions

View file

@ -16,6 +16,7 @@ from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
@ -89,7 +90,21 @@ async def safety_stack(inference_model, safety_model, request):
# Register the appropriate shield based on provider type
provider_type = safety_fixture.providers[0].provider_type
shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
provider_id = inference_fixture.providers[0].provider_id
print(f"Registering model {inference_model} with provider {provider_id}")
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
)
return safety_impl, shields_impl, shield
async def create_and_register_shield(
provider_type: str, safety_model: str, shields_impl
):
shield_config = {}
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
@ -102,10 +117,8 @@ async def safety_stack(inference_model, safety_model, request):
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
shield_type = ShieldType.generic_content_shield
shield = await shields_impl.register_shield(
return await shields_impl.register_shield(
shield_id=identifier,
shield_type=shield_type,
params=shield_config,
)
return safety_impl, shields_impl, shield