Update the meta reference safety implementation to match new API

This commit is contained in:
Ashwin Bharambe 2024-09-20 14:17:44 -07:00 committed by Xi Yan
parent 7e40eead4e
commit 82ddd851c8
11 changed files with 115 additions and 130 deletions

View file

@ -7,7 +7,7 @@
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, ShieldDefinition
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
@ -21,8 +21,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
self,
tool: BaseTool,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
input_shields: List[str] = None,
output_shields: List[str] = None,
):
self._tool = tool
ShieldRunnerMixin.__init__(
@ -30,7 +30,6 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
)
def get_name(self) -> str:
# return the name of the wrapped tool
return self._tool.get_name()
async def run(self, messages: List[Message]) -> List[Message]:
@ -47,8 +46,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
def with_safety(
tool: BaseTool,
safety_api: Safety,
input_shields: List[ShieldDefinition] = None,
output_shields: List[ShieldDefinition] = None,
input_shields: List[str] = None,
output_shields: List[str] = None,
) -> SafeTool:
return SafeTool(
tool,