Introduce model_store, shield_store, memory_bank_store

This commit is contained in:
Ashwin Bharambe 2024-10-06 16:29:33 -07:00 committed by Ashwin Bharambe
parent e45a417543
commit 91e0063593
19 changed files with 172 additions and 297 deletions

View file

@ -83,15 +83,6 @@ class FaissMemoryImpl(Memory):
)
self.cache[memory_bank.identifier] = index
async def get_memory_bank(self, identifier: str) -> Optional[MemoryBankDef]:
index = self.cache.get(identifier)
if index is None:
return None
return index.bank
async def list_memory_banks(self) -> List[MemoryBankDef]:
return [x.bank for x in self.cache.values()]
async def insert_documents(
self,
bank_id: str,

View file

@ -33,7 +33,6 @@ class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
self.registered_shields = []
self.available_shields = [ShieldType.code_scanner.value]
if config.llama_guard_shield:
@ -55,24 +54,13 @@ class MetaReferenceSafetyImpl(Safety):
if shield.type not in self.available_shields:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
self.registered_shields.append(shield)
async def list_shields(self) -> List[ShieldDef]:
return self.registered_shields
async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
for shield in self.registered_shields:
if shield.identifier == identifier:
return shield
return None
async def run_shield(
self,
shield_type: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield_def = await self.get_shield(shield_type)
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")