This commit is contained in:
Dinesh Yeduguru 2024-11-07 12:09:14 -08:00
parent 7ee9f8d8ac
commit d960f9b60f
16 changed files with 140 additions and 105 deletions

View file

@ -32,7 +32,7 @@ RoutingKey = Union[str, List[str]]
RoutableObject = Union[
ModelDef,
ShieldDef,
Shield,
MemoryBankDef,
DatasetDef,
ScoringFnDef,
@ -42,7 +42,7 @@ RoutableObject = Union[
RoutableObjectWithProvider = Annotated[
Union[
ModelDefWithProvider,
ShieldDefWithProvider,
Shield,
MemoryBankDefWithProvider,
DatasetDefWithProvider,
ScoringFnDefWithProvider,

View file

@ -150,17 +150,17 @@ class SafetyRouter(Safety):
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: ShieldDef) -> None:
async def register_shield(self, shield: Shield) -> None:
await self.routing_table.register_shield(shield)
async def run_shield(
self,
identifier: str,
shield: Shield,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
return await self.routing_table.get_provider_impl(identifier).run_shield(
identifier=identifier,
return await self.routing_table.get_provider_impl(shield.identifier).run_shield(
shield=shield,
messages=messages,
params=params,
)

View file

@ -87,11 +87,6 @@ class CommonRoutingTableImpl(RoutingTable):
models = await p.list_models()
await add_objects(models, pid, ModelDefWithProvider)
elif api == Api.safety:
p.shield_store = self
shields = await p.list_shields()
await add_objects(shields, pid, ShieldDefWithProvider)
elif api == Api.memory:
p.memory_bank_store = self
memory_banks = await p.list_memory_banks()
@ -212,13 +207,13 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldDef]:
async def list_shields(self) -> List[Shield]:
return await self.get_all_with_type("shield")
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]:
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier(identifier)
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
async def register_shield(self, shield: Shield) -> None:
await self.register_object(shield)