llama-stack/llama_stack/providers/inline/agents/meta_reference/tools/safety.py
Ashwin Bharambe 694c142b89
Add provider deprecation support; change directory structure (#397)
* Add provider deprecation support; change directory structure

* fix a couple dangling imports

* move the meta_reference safety dir also
2024-11-07 13:04:53 -08:00

42 lines
1.3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import * # noqa: F403
from ..safety import ShieldRunnerMixin
from .builtin import BaseTool
class SafeTool(BaseTool, ShieldRunnerMixin):
"""A tool that makes other tools safety enabled"""
def __init__(
self,
tool: BaseTool,
safety_api: Safety,
input_shields: List[str] = None,
output_shields: List[str] = None,
):
self._tool = tool
ShieldRunnerMixin.__init__(
self, safety_api, input_shields=input_shields, output_shields=output_shields
)
def get_name(self) -> str:
return self._tool.get_name()
async def run(self, messages: List[Message]) -> List[Message]:
if self.input_shields:
await self.run_multiple_shields(messages, self.input_shields)
# run the underlying tool
res = await self._tool.run(messages)
if self.output_shields:
await self.run_multiple_shields(messages, self.output_shields)
return res