mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 20:27:35 +00:00
Update the meta reference safety implementation to match new API
This commit is contained in:
parent
7e40eead4e
commit
82ddd851c8
11 changed files with 115 additions and 130 deletions
|
@ -7,7 +7,7 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, ShieldDefinition
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.agents.safety import ShieldRunnerMixin
|
||||
|
||||
|
@ -21,8 +21,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
|
|||
self,
|
||||
tool: BaseTool,
|
||||
safety_api: Safety,
|
||||
input_shields: List[ShieldDefinition] = None,
|
||||
output_shields: List[ShieldDefinition] = None,
|
||||
input_shields: List[str] = None,
|
||||
output_shields: List[str] = None,
|
||||
):
|
||||
self._tool = tool
|
||||
ShieldRunnerMixin.__init__(
|
||||
|
@ -30,7 +30,6 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
def get_name(self) -> str:
|
||||
# return the name of the wrapped tool
|
||||
return self._tool.get_name()
|
||||
|
||||
async def run(self, messages: List[Message]) -> List[Message]:
|
||||
|
@ -47,8 +46,8 @@ class SafeTool(BaseTool, ShieldRunnerMixin):
|
|||
def with_safety(
|
||||
tool: BaseTool,
|
||||
safety_api: Safety,
|
||||
input_shields: List[ShieldDefinition] = None,
|
||||
output_shields: List[ShieldDefinition] = None,
|
||||
input_shields: List[str] = None,
|
||||
output_shields: List[str] = None,
|
||||
) -> SafeTool:
|
||||
return SafeTool(
|
||||
tool,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue