From 15ffceb533702a144fae060c69d4e8133fe3f707 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 11 Nov 2024 09:09:47 -0800 Subject: [PATCH] more fixes (some fixes to pre-existing issues in safety fixture) --- .../safety/code_scanner/code_scanner.py | 2 +- .../inline/safety/llama_guard/llama_guard.py | 3 ++- .../safety/prompt_guard/prompt_guard.py | 2 +- .../remote/inference/bedrock/__init__.py | 3 ++- .../providers/tests/inference/fixtures.py | 5 +++-- .../providers/tests/safety/fixtures.py | 19 ++++++++++++++++--- 6 files changed, 25 insertions(+), 9 deletions(-) diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 36ad60b8e..1ca65c9bb 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -25,7 +25,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.code_scanner.value: + if shield.shield_type != ShieldType.code_scanner: raise ValueError(f"Unsupported safety shield type: {shield.shield_type}") async def run_shield( diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 25f1805b9..9c3ec7750 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -128,7 +128,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.llama_guard.value: + print(f"Registering shield {shield}") + if shield.shield_type != ShieldType.llama_guard: raise ValueError(f"Unsupported shield type: {shield.shield_type}") async def run_shield( diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 7754607ec..20bfdd241 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -36,7 +36,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): pass async def register_shield(self, shield: Shield) -> None: - if shield.shield_type != ShieldType.prompt_guard.value: + if shield.shield_type != ShieldType.prompt_guard: raise ValueError(f"Unsupported shield type: {shield.shield_type}") async def run_shield( diff --git a/llama_stack/providers/remote/inference/bedrock/__init__.py b/llama_stack/providers/remote/inference/bedrock/__init__.py index a38af374a..e72c6ada9 100644 --- a/llama_stack/providers/remote/inference/bedrock/__init__.py +++ b/llama_stack/providers/remote/inference/bedrock/__init__.py @@ -3,11 +3,12 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .bedrock import BedrockInferenceAdapter from .config import BedrockConfig async def get_adapter_impl(config: BedrockConfig, _deps): + from .bedrock import BedrockInferenceAdapter + assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}" impl = BedrockInferenceAdapter(config) diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index b2c6d3a5e..d91337998 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -65,7 +65,6 @@ def inference_ollama(inference_model) -> ProviderFixture: inference_model = ( [inference_model] if isinstance(inference_model, str) else inference_model ) - print("!!!", inference_model) if "Llama3.1-8B-Instruct" in inference_model: pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing") @@ -162,9 +161,11 @@ async def inference_stack(request, inference_model): inference_fixture.provider_data, ) + provider_id = inference_fixture.providers[0].provider_id + print(f"Registering model {inference_model} with provider {provider_id}") await impls[Api.models].register_model( model_id=inference_model, - provider_model_id=inference_fixture.providers[0].provider_id, + provider_id=provider_id, ) return (impls[Api.inference], impls[Api.models]) diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index c2beff7e2..5b4c07de5 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -16,6 +16,7 @@ from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from ..conftest import ProviderFixture, remote_stack_fixture +from ..env import get_env_or_fail @pytest.fixture(scope="session") @@ -89,7 +90,21 @@ async def safety_stack(inference_model, safety_model, request): # Register the appropriate shield based on provider type provider_type = safety_fixture.providers[0].provider_type + shield = await create_and_register_shield(provider_type, safety_model, shields_impl) + provider_id = inference_fixture.providers[0].provider_id + print(f"Registering model {inference_model} with provider {provider_id}") + await impls[Api.models].register_model( + model_id=inference_model, + provider_id=provider_id, + ) + + return safety_impl, shields_impl, shield + + +async def create_and_register_shield( + provider_type: str, safety_model: str, shields_impl +): shield_config = {} shield_type = ShieldType.llama_guard identifier = "llama_guard" @@ -102,10 +117,8 @@ async def safety_stack(inference_model, safety_model, request): shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION") shield_type = ShieldType.generic_content_shield - shield = await shields_impl.register_shield( + return await shields_impl.register_shield( shield_id=identifier, shield_type=shield_type, params=shield_config, ) - - return safety_impl, shields_impl, shield