forked from phoenix-oss/llama-stack-mirror
Add provider deprecation support; change directory structure (#397)
* Add provider deprecation support; change directory structure * fix a couple dangling imports * move the meta_reference safety dir also
This commit is contained in:
parent
36e2538eb0
commit
694c142b89
58 changed files with 61 additions and 120 deletions
57
llama_stack/providers/inline/safety/meta_reference/base.py
Normal file
57
llama_stack/providers/inline/safety/meta_reference/base.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
|
||||
from pydantic import BaseModel
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||
|
||||
|
||||
# TODO: clean this up; just remove this type completely
|
||||
class ShieldResponse(BaseModel):
|
||||
is_violation: bool
|
||||
violation_type: Optional[str] = None
|
||||
violation_return_message: Optional[str] = None
|
||||
|
||||
|
||||
# TODO: this is a caller / agent concern
|
||||
class OnViolationAction(Enum):
|
||||
IGNORE = 0
|
||||
WARN = 1
|
||||
RAISE = 2
|
||||
|
||||
|
||||
class ShieldBase(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
):
|
||||
self.on_violation_action = on_violation_action
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def message_content_as_str(message: Message) -> str:
|
||||
return interleaved_text_media_as_str(message.content)
|
||||
|
||||
|
||||
class TextShield(ShieldBase):
|
||||
def convert_messages_to_text(self, messages: List[Message]) -> str:
|
||||
return "\n".join([message_content_as_str(m) for m in messages])
|
||||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
text = self.convert_messages_to_text(messages)
|
||||
return await self.run_impl(text)
|
||||
|
||||
@abstractmethod
|
||||
async def run_impl(self, text: str) -> ShieldResponse:
|
||||
raise NotImplementedError()
|
Loading…
Add table
Add a link
Reference in a new issue