Split safety into (llama-guard, prompt-guard, code-scanner) (#400)

Splits the meta-reference safety implementation into three distinct providers:

- inline::llama-guard
- inline::prompt-guard
- inline::code-scanner

Note that this PR is a backward incompatible change to the llama stack server. I have added deprecation_error field to ProviderSpec -- the server reads it and immediately barfs. This is used to direct the user with a specific message on what action to perform. An automagical "config upgrade" is a bit too much work to implement right now :/

(Note that we will be gradually prefixing all inline providers with inline:: -- I am only doing this for this set of new providers because otherwise existing configuration files will break even more badly.)
This commit is contained in:
Ashwin Bharambe 2024-11-11 09:29:18 -08:00 committed by GitHub
parent 6d38b1690b
commit c1f7ba3aed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
47 changed files with 464 additions and 500 deletions

View file

@ -10,15 +10,14 @@ import pytest_asyncio
from llama_stack.apis.shields import ShieldType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.meta_reference import (
LlamaGuardShieldConfig,
SafetyConfig,
)
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
@ -34,17 +33,29 @@ def safety_model(request):
@pytest.fixture(scope="session")
def safety_meta_reference(safety_model) -> ProviderFixture:
def safety_llama_guard(safety_model) -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config=SafetyConfig(
llama_guard_shield=LlamaGuardShieldConfig(
model=safety_model,
),
).model_dump(),
provider_id="inline::llama-guard",
provider_type="inline::llama-guard",
config=LlamaGuardConfig(model=safety_model).model_dump(),
)
],
)
# TODO: this is not tested yet; we would need to configure the run_shield() test
# and parametrize it with the "prompt" for testing depending on the safety fixture
# we are using.
@pytest.fixture(scope="session")
def safety_prompt_guard() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="inline::prompt-guard",
provider_type="inline::prompt-guard",
config=PromptGuardConfig().model_dump(),
)
],
)
@ -63,7 +74,7 @@ def safety_bedrock() -> ProviderFixture:
)
SAFETY_FIXTURES = ["meta_reference", "bedrock", "remote"]
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
@pytest_asyncio.fixture(scope="session")
@ -96,7 +107,21 @@ async def safety_stack(inference_model, safety_model, request):
# Register the appropriate shield based on provider type
provider_type = safety_fixture.providers[0].provider_type
shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
provider_id = inference_fixture.providers[0].provider_id
print(f"Registering model {inference_model} with provider {provider_id}")
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
)
return safety_impl, shields_impl, shield
async def create_and_register_shield(
provider_type: str, safety_model: str, shields_impl
):
shield_config = {}
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
@ -109,10 +134,8 @@ async def safety_stack(inference_model, safety_model, request):
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
shield_type = ShieldType.generic_content_shield
shield = await shields_impl.register_shield(
return await shields_impl.register_shield(
shield_id=identifier,
shield_type=shield_type,
params=shield_config,
)
return safety_impl, shields_impl, shield