rebase on top of Dinesh's refactor

This commit is contained in:
Ashwin Bharambe 2024-11-11 08:46:20 -08:00
parent a7f728e41c
commit 7507cd487f
2 changed files with 16 additions and 32 deletions

View file

@ -127,27 +127,19 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_shield(self, shield: ShieldDef) -> None: async def register_shield(self, shield: Shield) -> None:
raise ValueError("Registering dynamic shields is not supported") if shield.shield_type != ShieldType.llama_guard.value:
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
async def list_shields(self) -> List[ShieldDef]:
return [
ShieldDef(
identifier=ShieldType.llama_guard.value,
shield_type=ShieldType.llama_guard.value,
params={},
),
]
async def run_shield( async def run_shield(
self, self,
identifier: str, shield_id: str,
messages: List[Message], messages: List[Message],
params: Dict[str, Any] = None, params: Dict[str, Any] = None,
) -> RunShieldResponse: ) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(identifier) shield = await self.shield_store.get_shield(shield_id)
if not shield_def: if not shield:
raise ValueError(f"Unknown shield {identifier}") raise ValueError(f"Unknown shield {shield_id}")
messages = messages.copy() messages = messages.copy()
# some shields like llama-guard require the first message to be a user message # some shields like llama-guard require the first message to be a user message

View file

@ -7,11 +7,11 @@
from typing import Any, Dict, List from typing import Any, Dict, List
import torch import torch
from llama_stack.distribution.utils.model_utils import model_local_dir
from termcolor import cprint from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
@ -35,27 +35,19 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_shield(self, shield: ShieldDef) -> None: async def register_shield(self, shield: Shield) -> None:
raise ValueError("Registering dynamic shields is not supported") if shield.shield_type != ShieldType.prompt_guard.value:
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
async def list_shields(self) -> List[ShieldDef]:
return [
ShieldDef(
identifier=ShieldType.prompt_guard.value,
shield_type=ShieldType.prompt_guard.value,
params={},
)
]
async def run_shield( async def run_shield(
self, self,
identifier: str, shield_id: str,
messages: List[Message], messages: List[Message],
params: Dict[str, Any] = None, params: Dict[str, Any] = None,
) -> RunShieldResponse: ) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(identifier) shield = await self.shield_store.get_shield(shield_id)
if not shield_def: if not shield:
raise ValueError(f"Unknown shield {identifier}") raise ValueError(f"Unknown shield {shield_id}")
return await self.shield.run(messages) return await self.shield.run(messages)