shield_type->identifier

This commit is contained in:
Xi Yan 2024-11-06 16:25:35 -08:00
parent 67bfd946dc
commit ea9f5368cb
3 changed files with 6 additions and 6 deletions

View file

@ -48,5 +48,5 @@ class Safety(Protocol):
@webmethod(route="/safety/run_shield")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ...

View file

@ -154,12 +154,12 @@ class SafetyRouter(Safety):
async def run_shield(
self,
shield_type: str,
identifier: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
return await self.routing_table.get_provider_impl(identifier).run_shield(
shield_type=shield_type,
identifier=identifier,
messages=messages,
params=params,
)

View file

@ -57,13 +57,13 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
async def run_shield(
self,
shield_type: str,
identifier: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type)
shield_def = await self.shield_store.get_shield(identifier)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
raise ValueError(f"Unknown shield {identifier}")
shield = self.get_shield_impl(shield_def)