mirror of https://github.com/meta-llama/llama-stack.git

Kill the notion of shield_type

parent 09269e2a44
commit b1c3a95485

20 changed files with 87 additions and 161 deletions
@@ -6,32 +6,8 @@
 from typing import List
 
-from llama_models.sku_list import CoreModelId, safety_models
-
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel
 
 
 class LlamaGuardConfig(BaseModel):
-    model: str = "Llama-Guard-3-1B"
     excluded_categories: List[str] = []
-
-    @field_validator("model")
-    @classmethod
-    def validate_model(cls, model: str) -> str:
-        permitted_models = [
-            m.descriptor()
-            for m in safety_models()
-            if (
-                m.core_model_id
-                in {
-                    CoreModelId.llama_guard_3_8b,
-                    CoreModelId.llama_guard_3_1b,
-                    CoreModelId.llama_guard_3_11b_vision,
-                }
-            )
-        ]
-        if model not in permitted_models:
-            raise ValueError(
-                f"Invalid model: {model}. Must be one of {permitted_models}"
-            )
-        return model
 
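A note on the config half of the change: after this hunk, LlamaGuardConfig no longer names a model, so picking a Llama Guard variant is deferred until a shield is registered (see the later hunks). A minimal, runnable sketch of the post-change config follows; the value "S10" is a made-up category code for illustration, not from this diff.

# Standalone reproduction of the post-change config, for illustration only.
from typing import List

from pydantic import BaseModel


class LlamaGuardConfig(BaseModel):
    excluded_categories: List[str] = []


# The provider config now carries only category exclusions; the model is
# chosen per registered shield instead. "S10" is a hypothetical code.
config = LlamaGuardConfig(excluded_categories=["S10"])
print(config.excluded_categories)  # ['S10']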
@@ -73,6 +73,11 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
     CAT_ELECTIONS,
 ]
 
+LLAMA_GUARD_MODEL_IDS = [
+    CoreModelId.llama_guard_3_8b.value,
+    CoreModelId.llama_guard_3_1b.value,
+    CoreModelId.llama_guard_3_11b_vision.value,
+]
+
 MODEL_TO_SAFETY_CATEGORIES_MAP = {
     CoreModelId.llama_guard_3_8b.value: (
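This allow-list takes over the job of the pydantic field_validator deleted from the config. A hedged sketch of the pattern, with stand-in string values, since the CoreModelId members' .value strings and the full category map are not spelled out in this diff; the point is that the ID list and the per-model category map sit side by side and can act as one source of truth.

# Stand-in values; the real entries come from CoreModelId and the
# category constants defined earlier in the file.
LLAMA_GUARD_MODEL_IDS = ["Llama-Guard-3-8B", "Llama-Guard-3-1B"]
MODEL_TO_SAFETY_CATEGORIES_MAP = {
    "Llama-Guard-3-8B": ["CAT_VIOLENT_CRIMES", "CAT_ELECTIONS"],
    "Llama-Guard-3-1B": ["CAT_VIOLENT_CRIMES"],
}

# A model is permitted exactly when it has a category mapping.
assert set(LLAMA_GUARD_MODEL_IDS) == set(MODEL_TO_SAFETY_CATEGORIES_MAP)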
@@ -118,18 +123,16 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         self.inference_api = deps[Api.inference]
 
     async def initialize(self) -> None:
-        self.shield = LlamaGuardShield(
-            model=self.config.model,
-            inference_api=self.inference_api,
-            excluded_categories=self.config.excluded_categories,
-        )
+        pass
 
     async def shutdown(self) -> None:
         pass
 
     async def register_shield(self, shield: Shield) -> None:
-        if shield.shield_type != ShieldType.llama_guard:
-            raise ValueError(f"Unsupported shield type: {shield.shield_type}")
+        if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
+            raise ValueError(
+                f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
+            )
 
     async def run_shield(
         self,
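The substance of the commit is visible in this hunk: a shield is no longer validated by a shield_type enum but by the model it resolves to, its provider_resource_id. A hedged before-and-after sketch; Shield below is a minimal stand-in for the real resource class, which has more fields than shown.

from dataclasses import dataclass


@dataclass
class Shield:
    identifier: str
    provider_resource_id: str  # e.g. "Llama-Guard-3-1B"


# Before: if shield.shield_type != ShieldType.llama_guard: raise ...
# After: the check keys off the underlying model instead.
def validate_shield(shield: Shield, allowed: list) -> None:
    if shield.provider_resource_id not in allowed:
        raise ValueError(
            f"Unsupported Llama Guard type: {shield.provider_resource_id}. "
            f"Allowed types: {allowed}"
        )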
@@ -147,7 +150,13 @@
         if len(messages) > 0 and messages[0].role != Role.user.value:
             messages[0] = UserMessage(content=messages[0].content)
 
-        return await self.shield.run(messages)
+        impl = LlamaGuardShield(
+            model=shield.provider_resource_id,
+            inference_api=self.inference_api,
+            excluded_categories=self.config.excluded_categories,
+        )
+
+        return await impl.run(messages)
 
 
 class LlamaGuardShield:
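Deferring construction is what makes the per-shield model work: run_shield now builds a LlamaGuardShield from the registered shield's provider_resource_id on every call, so one provider instance can back several shields with different Llama Guard models. A simplified, hedged sketch of that pattern; the stubbed run method stands in for the real prompt-and-inference logic.

# Simplified stand-in for the real LlamaGuardShield; the constructor
# mirrors the call in the hunk above.
class LlamaGuardShield:
    def __init__(self, model, inference_api=None, excluded_categories=None):
        self.model = model
        self.inference_api = inference_api
        self.excluded_categories = excluded_categories or []

    async def run(self, messages):
        # Real code renders a Llama Guard prompt and calls the inference
        # API; stubbed here for illustration.
        return f"checked {len(messages)} message(s) with {self.model}"


async def run_shield(shield, messages, inference_api, config):
    # Built per request: the model comes from the registered shield,
    # not from a single object created in initialize().
    impl = LlamaGuardShield(
        model=shield.provider_resource_id,
        inference_api=inference_api,
        excluded_categories=config.excluded_categories,
    )
    return await impl.run(messages)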