From 932b52444965214b1bed3fbb131e54f33429bd27 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 7 Nov 2024 20:15:42 -0800 Subject: [PATCH] use env vars for bedrock guardrail vars --- llama_stack/providers/tests/safety/conftest.py | 12 ------------ llama_stack/providers/tests/safety/fixtures.py | 7 +++---- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index 252187251..daf16aefc 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -71,18 +71,6 @@ def pytest_addoption(parser): default=None, help="Specify the safety model to use for testing", ) - parser.addoption( - "--bedrock-guardrail-id", - action="store", - default=None, - help="Specify the guard rail ID to use for testing bedrock", - ) - parser.addoption( - "--bedrock-guardrail-version", - action="store", - default=None, - help="Specify the guard rail version to use for testing bedrock", - ) SAFETY_MODEL_PARAMS = [ diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 3a374815f..40c89de92 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -15,6 +15,7 @@ from llama_stack.providers.inline.safety.meta_reference import ( SafetyConfig, ) from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig +from llama_stack.providers.tests.env import get_env_or_fail from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from ..conftest import ProviderFixture, remote_stack_fixture @@ -120,10 +121,8 @@ async def safety_stack(inference_model, safety_model, request): elif provider_type == "remote::together": shield_config["model"] = safety_model elif provider_type == "remote::bedrock": - identifier = request.config.getoption("--bedrock-guardrail-id", None) - shield_config["guardrailVersion"] = request.config.getoption( - "--bedrock-guardrail-version", None - ) + identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER") + shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") # Create shield shield = Shield(