more fixes (some fixes to pre-existing issues in safety fixture)

This commit is contained in:
Ashwin Bharambe 2024-11-11 09:09:47 -08:00
parent 7507cd487f
commit 15ffceb533
6 changed files with 25 additions and 9 deletions

View file

@ -25,7 +25,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
pass
async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.code_scanner.value:
if shield.shield_type != ShieldType.code_scanner:
raise ValueError(f"Unsupported safety shield type: {shield.shield_type}")
async def run_shield(

View file

@ -128,7 +128,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.llama_guard.value:
print(f"Registering shield {shield}")
if shield.shield_type != ShieldType.llama_guard:
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
async def run_shield(

View file

@ -36,7 +36,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.prompt_guard.value:
if shield.shield_type != ShieldType.prompt_guard:
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
async def run_shield(

View file

@ -3,11 +3,12 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .bedrock import BedrockInferenceAdapter
from .config import BedrockConfig
async def get_adapter_impl(config: BedrockConfig, _deps):
from .bedrock import BedrockInferenceAdapter
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
impl = BedrockInferenceAdapter(config)

View file

@ -65,7 +65,6 @@ def inference_ollama(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
print("!!!", inference_model)
if "Llama3.1-8B-Instruct" in inference_model:
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
@ -162,9 +161,11 @@ async def inference_stack(request, inference_model):
inference_fixture.provider_data,
)
provider_id = inference_fixture.providers[0].provider_id
print(f"Registering model {inference_model} with provider {provider_id}")
await impls[Api.models].register_model(
model_id=inference_model,
provider_model_id=inference_fixture.providers[0].provider_id,
provider_id=provider_id,
)
return (impls[Api.inference], impls[Api.models])

View file

@ -16,6 +16,7 @@ from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
@ -89,7 +90,21 @@ async def safety_stack(inference_model, safety_model, request):
# Register the appropriate shield based on provider type
provider_type = safety_fixture.providers[0].provider_type
shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
provider_id = inference_fixture.providers[0].provider_id
print(f"Registering model {inference_model} with provider {provider_id}")
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
)
return safety_impl, shields_impl, shield
async def create_and_register_shield(
provider_type: str, safety_model: str, shields_impl
):
shield_config = {}
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
@ -102,10 +117,8 @@ async def safety_stack(inference_model, safety_model, request):
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
shield_type = ShieldType.generic_content_shield
shield = await shields_impl.register_shield(
return await shields_impl.register_shield(
shield_id=identifier,
shield_type=shield_type,
params=shield_config,
)
return safety_impl, shields_impl, shield