Make Safety test work, other cleanup

This commit is contained in:
Ashwin Bharambe 2024-10-09 21:09:50 -07:00
parent ba1f294cc6
commit fcd22b6baa
16 changed files with 229 additions and 123 deletions

View file

@ -50,15 +50,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
)
]
async def get_model(self, identifier: str) -> Optional[ModelDef]:
if self.model.descriptor() != identifier:
return None
return ModelDef(
identifier=self.model.descriptor(),
llama_model=self.model.descriptor(),
)
async def shutdown(self) -> None:
self.generator.stop()

View file

@ -85,13 +85,6 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
)
self.cache[memory_bank.identifier] = index
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
banks = await self.list_memory_banks()
for bank in banks:
if bank.identifier == identifier:
return bank
return None
async def list_memory_banks(self) -> List[MemoryBankDef]:
return [i.bank for i in self.cache.values()]

View file

@ -12,6 +12,8 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .base import OnViolationAction, ShieldBase
from .config import SafetyConfig
from .llama_guard import LlamaGuardShield
@ -21,7 +23,7 @@ from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
class MetaReferenceSafetyImpl(Safety):
class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
@ -41,8 +43,17 @@ class MetaReferenceSafetyImpl(Safety):
pass
async def register_shield(self, shield: ShieldDef) -> None:
if shield.type not in self.available_shields:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
raise ValueError("Registering dynamic shields is not supported")
async def list_shields(self) -> List[ShieldDef]:
return [
ShieldDef(
identifier=shield_type,
type=shield_type,
params={},
)
for shield_type in self.available_shields
]
async def run_shield(
self,