Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-18 15:27:16 +00:00
Kill the notion of shield_type
commit b1c3a95485 (parent 09269e2a44)
20 changed files with 87 additions and 161 deletions
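The pattern across these hunks: each safety provider used to validate an enum field (shield.shield_type against ShieldType.llama_guard, ShieldType.code_scanner, and so on); after this commit it validates shield.provider_resource_id, i.e. the concrete model or guardrail backing the shield. A minimal self-contained sketch of the before/after check — the Shield dataclass here is an illustrative stand-in, not the real llama_stack type:

from dataclasses import dataclass


@dataclass
class Shield:
    # Illustrative stand-in for the llama_stack Shield resource.
    identifier: str
    provider_resource_id: str  # id of the backing model/guardrail


ALLOWED_IDS = ["CodeScanner", "CodeShield"]  # mirrors ALLOWED_CODE_SCANNER_MODEL_IDS


def register_shield(shield: Shield) -> None:
    # Before: `if shield.shield_type != ShieldType.code_scanner: raise ...`
    # After: validate the concrete resource id instead of an abstract enum.
    if shield.provider_resource_id not in ALLOWED_IDS:
        raise ValueError(
            f"Unsupported ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_IDS}"
        )


register_shield(Shield(identifier="my-scanner", provider_resource_id="CodeShield"))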
@@ -14,6 +14,12 @@ from .config import CodeScannerConfig
 from llama_stack.apis.safety import *  # noqa: F403
 
+ALLOWED_CODE_SCANNER_MODEL_IDS = [
+    "CodeScanner",
+    "CodeShield",
+]
+
+
 class MetaReferenceCodeScannerSafetyImpl(Safety):
     def __init__(self, config: CodeScannerConfig, deps) -> None:
         self.config = config
@@ -25,8 +31,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
         pass
 
     async def register_shield(self, shield: Shield) -> None:
-        if shield.shield_type != ShieldType.code_scanner:
-            raise ValueError(f"Unsupported safety shield type: {shield.shield_type}")
+        if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS:
+            raise ValueError(
+                f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
+            )
 
     async def run_shield(
         self,
@@ -6,32 +6,8 @@
 
 from typing import List
 
-from llama_models.sku_list import CoreModelId, safety_models
-
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel
 
 
 class LlamaGuardConfig(BaseModel):
-    model: str = "Llama-Guard-3-1B"
     excluded_categories: List[str] = []
-
-    @field_validator("model")
-    @classmethod
-    def validate_model(cls, model: str) -> str:
-        permitted_models = [
-            m.descriptor()
-            for m in safety_models()
-            if (
-                m.core_model_id
-                in {
-                    CoreModelId.llama_guard_3_8b,
-                    CoreModelId.llama_guard_3_1b,
-                    CoreModelId.llama_guard_3_11b_vision,
-                }
-            )
-        ]
-        if model not in permitted_models:
-            raise ValueError(
-                f"Invalid model: {model}. Must be one of {permitted_models}"
-            )
-        return model
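With the model field gone, LlamaGuardConfig no longer needs the field_validator that cross-checked the model against llama_models' safety model list; that check effectively moves to register_shield (next hunks), keyed on LLAMA_GUARD_MODEL_IDS. A sketch of the resulting config, matching the hunk above — the category code in the usage line is illustrative:

from typing import List

from pydantic import BaseModel


class LlamaGuardConfig(BaseModel):
    # The model is now chosen per shield at registration time,
    # so the config only carries category exclusions.
    excluded_categories: List[str] = []


cfg = LlamaGuardConfig(excluded_categories=["S7"])  # "S7" is an illustrative code
print(cfg.model_dump())  # {'excluded_categories': ['S7']}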
@@ -73,6 +73,11 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
     CAT_ELECTIONS,
 ]
 
+LLAMA_GUARD_MODEL_IDS = [
+    CoreModelId.llama_guard_3_8b.value,
+    CoreModelId.llama_guard_3_1b.value,
+    CoreModelId.llama_guard_3_11b_vision.value,
+]
+
 MODEL_TO_SAFETY_CATEGORIES_MAP = {
     CoreModelId.llama_guard_3_8b.value: (
@@ -118,18 +123,16 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         self.inference_api = deps[Api.inference]
 
     async def initialize(self) -> None:
-        self.shield = LlamaGuardShield(
-            model=self.config.model,
-            inference_api=self.inference_api,
-            excluded_categories=self.config.excluded_categories,
-        )
+        pass
 
     async def shutdown(self) -> None:
         pass
 
     async def register_shield(self, shield: Shield) -> None:
-        if shield.shield_type != ShieldType.llama_guard:
-            raise ValueError(f"Unsupported shield type: {shield.shield_type}")
+        if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
+            raise ValueError(
+                f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
+            )
 
     async def run_shield(
         self,
@@ -147,7 +150,13 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         if len(messages) > 0 and messages[0].role != Role.user.value:
             messages[0] = UserMessage(content=messages[0].content)
 
-        return await self.shield.run(messages)
+        impl = LlamaGuardShield(
+            model=shield.provider_resource_id,
+            inference_api=self.inference_api,
+            excluded_categories=self.config.excluded_categories,
+        )
+
+        return await impl.run(messages)
 
 
 class LlamaGuardShield:
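initialize() no longer builds a single LlamaGuardShield from the config; run_shield constructs one per call from the registered shield's provider_resource_id. A self-contained sketch of that construct-per-request pattern — FakeGuard and the registry dict are illustrative stand-ins for LlamaGuardShield and the shield store:

import asyncio


class FakeGuard:
    # Stand-in for LlamaGuardShield; the real impl calls the inference API.
    def __init__(self, model: str) -> None:
        self.model = model

    async def run(self, messages: list) -> str:
        return f"checked {len(messages)} message(s) with {self.model}"


# Stand-in for the shield store: shield_id -> provider_resource_id.
REGISTERED = {"my-shield": "Llama-Guard-3-1B"}


async def run_shield(shield_id: str, messages: list) -> str:
    model = REGISTERED[shield_id]
    impl = FakeGuard(model=model)  # built fresh per call, from the resource id
    return await impl.run(messages)


print(asyncio.run(run_shield("my-shield", ["hello"])))

The trade-off is a small per-request construction cost in exchange for letting one provider instance serve shields registered against different Llama Guard models.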
@@ -36,8 +36,10 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         pass
 
     async def register_shield(self, shield: Shield) -> None:
-        if shield.shield_type != ShieldType.prompt_guard:
-            raise ValueError(f"Unsupported shield type: {shield.shield_type}")
+        if shield.provider_resource_id != PROMPT_GUARD_MODEL:
+            raise ValueError(
+                f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. "
+            )
 
     async def run_shield(
         self,
@@ -20,11 +20,6 @@ from .config import BedrockSafetyConfig
 logger = logging.getLogger(__name__)
 
 
-BEDROCK_SUPPORTED_SHIELDS = [
-    ShieldType.generic_content_shield,
-]
-
-
 class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
     def __init__(self, config: BedrockSafetyConfig) -> None:
         self.config = config
@@ -81,15 +81,17 @@ async def create_agent_session(agents_impl, agent_config):
 
 class TestAgents:
     @pytest.mark.asyncio
-    async def test_agent_turns_with_safety(self, agents_stack, common_params):
+    async def test_agent_turns_with_safety(
+        self, safety_model, agents_stack, common_params
+    ):
         agents_impl, _ = agents_stack
         agent_id, session_id = await create_agent_session(
             agents_impl,
             AgentConfig(
                 **{
                     **common_params,
-                    "input_shields": ["llama_guard"],
-                    "output_shields": ["llama_guard"],
+                    "input_shields": [safety_model],
+                    "output_shields": [safety_model],
                 }
             ),
         )
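The agent test now takes the safety_model fixture and uses it directly as the shield id instead of hard-coding "llama_guard": with shield_type gone, the model id is the identifier. A minimal sketch of the resulting config shape — the model id here is illustrative:

safety_model = "Llama-Guard-3-1B"  # in the real test this comes from a pytest fixture

agent_config_overrides = {
    "input_shields": [safety_model],   # shield ids are model ids now,
    "output_shields": [safety_model],  # not the literal string "llama_guard"
}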
@@ -9,7 +9,7 @@ import pytest_asyncio
 
 from llama_stack.apis.models import ModelInput
 
-from llama_stack.apis.shields import ShieldInput, ShieldType
+from llama_stack.apis.shields import ShieldInput
 
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@@ -41,7 +41,7 @@ def safety_llama_guard(safety_model) -> ProviderFixture:
         Provider(
             provider_id="inline::llama-guard",
             provider_type="inline::llama-guard",
-            config=LlamaGuardConfig(model=safety_model).model_dump(),
+            config=LlamaGuardConfig().model_dump(),
         )
     ],
 )
@@ -114,20 +114,14 @@ async def safety_stack(inference_model, safety_model, request):
 
 
 def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
-    shield_config = {}
-    shield_type = ShieldType.llama_guard
-    identifier = "llama_guard"
-    if provider_type == "meta-reference":
-        shield_config["model"] = safety_model
-    elif provider_type == "remote::together":
-        shield_config["model"] = safety_model
-    elif provider_type == "remote::bedrock":
+    if provider_type == "remote::bedrock":
         identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
-        shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
-        shield_type = ShieldType.generic_content_shield
+        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
+    else:
+        params = {}
+        identifier = safety_model
 
     return ShieldInput(
         shield_id=identifier,
-        shield_type=shield_type,
-        params=shield_config,
+        params=params,
     )
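The fixture helper collapses to a single special case: Bedrock shields are identified by a guardrail id with a guardrailVersion param, and everything else is identified by the safety model itself. A sketch of the same branching with plain dicts — os.environ stands in for get_env_or_fail:

import os


def shield_input(provider_type: str, safety_model: str) -> dict:
    # Mirrors the simplified fixture: only Bedrock needs special-casing.
    if provider_type == "remote::bedrock":
        identifier = os.environ["BEDROCK_GUARDRAIL_IDENTIFIER"]
        params = {"guardrailVersion": os.environ["BEDROCK_GUARDRAIL_VERSION"]}
    else:
        identifier = safety_model
        params = {}
    return {"shield_id": identifier, "params": params}


print(shield_input("inline::llama-guard", "Llama-Guard-3-1B"))
# -> {'shield_id': 'Llama-Guard-3-1B', 'params': {}}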
@@ -34,7 +34,6 @@ class TestSafety:
 
         for shield in response:
             assert isinstance(shield, Shield)
-            assert shield.shield_type in [v for v in ShieldType]
 
     @pytest.mark.asyncio
     async def test_run_shield(self, safety_stack):