mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Union
|
|
|
|
from llama_models.llama3_1.api.datatypes import Attachment, Message
|
|
from llama_toolchain.safety.api.datatypes import * # noqa: F403
|
|
|
|
# Fixed refusal message. Nothing in the code visible here references it;
# presumably callers show it in place of a blocked response -- TODO confirm.
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
|
|
|
|
|
class ShieldBase(ABC):
    """Abstract base class for shields.

    Records the action to take on a violation and declares the two hooks
    every concrete shield has to supply: ``get_shield_type`` and ``run``.
    """

    def __init__(
        self,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        # Only stored here; nothing in this class acts on the value.
        self.on_violation_action = on_violation_action

    @abstractmethod
    def get_shield_type(self) -> ShieldType:
        """Return the type identifier of this shield."""
        raise NotImplementedError

    @abstractmethod
    async def run(self, messages: List[Message]) -> ShieldResponse:
        """Evaluate ``messages`` and return the resulting shield response."""
        raise NotImplementedError
|
|
|
|
|
|
def message_content_as_str(message: Message) -> str:
    """Flatten the content of ``message`` into a single string.

    ``message.content`` may be a single item or a list/tuple of items. A
    ``str`` item is used verbatim; an ``Attachment`` item is rendered as
    ``"File: <url>"``. Multiple items are joined with newlines.

    Raises:
        TypeError: if an item is neither a ``str`` nor an ``Attachment``.
    """

    def _to_str(content: Union[str, Attachment]) -> str:
        if isinstance(content, str):
            return content
        if isinstance(content, Attachment):
            return f"File: {str(content.url)}"
        # A bare ``raise`` with no active exception only produces an opaque
        # "RuntimeError: No active exception to reraise"; name the offending
        # type so the failure is diagnosable.
        raise TypeError(
            f"Unsupported message content type: {type(content).__name__}"
        )

    content = message.content
    if isinstance(content, (list, tuple)):
        return "\n".join(_to_str(item) for item in content)
    return _to_str(content)
|
|
|
|
|
|
# For shields that operate on simple strings
|
|
class TextShield(ShieldBase):
    """Shield base that evaluates one flattened string instead of messages.

    ``run`` converts the messages to text and hands the result to
    ``run_impl``, which subclasses must implement.
    """

    def convert_messages_to_text(self, messages: List[Message]) -> str:
        """Render every message as a string and join them with newlines."""
        rendered = map(message_content_as_str, messages)
        return "\n".join(rendered)

    async def run(self, messages: List[Message]) -> ShieldResponse:
        """Flatten ``messages`` to text and delegate to ``run_impl``."""
        return await self.run_impl(self.convert_messages_to_text(messages))

    @abstractmethod
    async def run_impl(self, text: str) -> ShieldResponse:
        """Evaluate the flattened ``text``; provided by subclasses."""
        raise NotImplementedError
|
|
|
|
|
|
class DummyShield(TextShield):
    """Placeholder shield that never reports a violation."""

    def get_shield_type(self) -> ShieldType:
        """Return the literal identifier ``"dummy"``."""
        return "dummy"

    async def run_impl(self, text: str) -> ShieldResponse:
        """Ignore ``text`` and always return a non-violation response."""
        # Unconditional "no violation" answer so end-to-end flows can be
        # exercised.
        # NOTE(review): the response reports third_party_shield while
        # get_shield_type() returns "dummy" -- confirm this is intended.
        response = ShieldResponse(
            shield_type=BuiltinShield.third_party_shield,
            is_violation=False,
        )
        return response
|