Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-17 07:07:19 +00:00)
add deprecation_error pointing meta-reference -> inline::llama-guard
commit 984ba074e1 (parent fdaec91747)
9 changed files with 84 additions and 74 deletions
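The hunk reproduced below is from the llama-guard safety module itself; the registry entry that the commit title refers to lives in one of the other changed files and is not shown here. As a rough illustration only, a deprecation_error entry of the kind the title describes might look like the sketch below; InlineProviderSpec, its field names, and the elided module/config paths are assumptions based on llama-stack conventions, not content taken from this diff.

# Hypothetical sketch of the registry-side change implied by the commit title.
# All names below are assumptions; only `deprecation_error` and the
# meta-reference -> inline::llama-guard direction come from the commit message.
from llama_stack.providers.datatypes import Api, InlineProviderSpec

meta_reference_safety = InlineProviderSpec(
    api=Api.safety,
    provider_type="meta-reference",
    pip_packages=[],
    module="llama_stack.providers...",   # placeholder; not shown in this diff
    config_class="...LlamaGuardConfig",  # placeholder; not shown in this diff
    deprecation_error="`meta-reference` safety is deprecated, please use `inline::llama-guard` instead",
)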
@@ -19,52 +19,6 @@ from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from .config import LlamaGuardConfig
 
 
-class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: LlamaGuardConfig, deps) -> None:
-        self.config = config
-        self.inference_api = deps[Api.inference]
-
-    async def initialize(self) -> None:
-        self.shield = LlamaGuardShield(
-            model=self.config.model,
-            inference_api=self.inference_api,
-            excluded_categories=self.config.excluded_categories,
-        )
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def register_shield(self, shield: ShieldDef) -> None:
-        raise ValueError("Registering dynamic shields is not supported")
-
-    async def list_shields(self) -> List[ShieldDef]:
-        return [
-            ShieldDef(
-                identifier=ShieldType.llama_guard.value,
-                shield_type=ShieldType.llama_guard.value,
-                params={},
-            ),
-        ]
-
-    async def run_shield(
-        self,
-        identifier: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield_def = await self.shield_store.get_shield(identifier)
-        if not shield_def:
-            raise ValueError(f"Unknown shield {identifier}")
-
-        messages = messages.copy()
-        # some shields like llama-guard require the first message to be a user message
-        # since this might be a tool call, first role might not be user
-        if len(messages) > 0 and messages[0].role != Role.user.value:
-            messages[0] = UserMessage(content=messages[0].content)
-
-        return await self.shield.run(messages)
-
-
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
 SAFE_RESPONSE = "safe"

@@ -158,6 +112,52 @@ PROMPT_TEMPLATE = Template(
 )
 
 
+class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+    def __init__(self, config: LlamaGuardConfig, deps) -> None:
+        self.config = config
+        self.inference_api = deps[Api.inference]
+
+    async def initialize(self) -> None:
+        self.shield = LlamaGuardShield(
+            model=self.config.model,
+            inference_api=self.inference_api,
+            excluded_categories=self.config.excluded_categories,
+        )
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def register_shield(self, shield: ShieldDef) -> None:
+        raise ValueError("Registering dynamic shields is not supported")
+
+    async def list_shields(self) -> List[ShieldDef]:
+        return [
+            ShieldDef(
+                identifier=ShieldType.llama_guard.value,
+                shield_type=ShieldType.llama_guard.value,
+                params={},
+            ),
+        ]
+
+    async def run_shield(
+        self,
+        identifier: str,
+        messages: List[Message],
+        params: Dict[str, Any] = None,
+    ) -> RunShieldResponse:
+        shield_def = await self.shield_store.get_shield(identifier)
+        if not shield_def:
+            raise ValueError(f"Unknown shield {identifier}")
+
+        messages = messages.copy()
+        # some shields like llama-guard require the first message to be a user message
+        # since this might be a tool call, first role might not be user
+        if len(messages) > 0 and messages[0].role != Role.user.value:
+            messages[0] = UserMessage(content=messages[0].content)
+
+        return await self.shield.run(messages)
+
+
 class LlamaGuardShield:
     def __init__(
         self,
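The relocated class keeps the same public surface, so callers are unaffected by the move. A minimal usage sketch, not part of the commit, assuming an already-constructed and initialized impl instance; the names used (LlamaGuardSafetyImpl, RunShieldResponse, UserMessage, ShieldType) are those already imported or defined by the module shown in the hunk above.

# Usage sketch only: `impl` is assumed to be a LlamaGuardSafetyImpl built with
# a LlamaGuardConfig plus an inference dependency and then initialized,
# exactly as __init__/initialize in the diff above describe.
async def moderate(impl: LlamaGuardSafetyImpl) -> RunShieldResponse:
    messages = [UserMessage(content="How do I hotwire a car?")]
    # run_shield looks up the shield by identifier; per list_shields() above,
    # this impl only exposes the llama-guard shield.
    return await impl.run_shield(
        identifier=ShieldType.llama_guard.value,
        messages=messages,
        params={},
    )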
|
Loading…
Add table
Add a link
Reference in a new issue