Introduce model_store, shield_store, memory_bank_store

This commit is contained in:
Ashwin Bharambe 2024-10-06 16:29:33 -07:00 committed by Ashwin Bharambe
parent e45a417543
commit 91e0063593
19 changed files with 172 additions and 297 deletions

View file

@ -7,14 +7,12 @@
import json
import logging
import traceback
from typing import Any, Dict, List
import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from .config import BedrockSafetyConfig
@ -22,16 +20,17 @@ from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
SUPPORTED_SHIELD_TYPES = [
"bedrock_guardrail",
BEDROCK_SUPPORTED_SHIELDS = [
ShieldType.generic_content_shield.value,
]
class BedrockSafetyAdapter(Safety, RoutableProvider):
class BedrockSafetyAdapter(Safety):
def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
self.config = config
self.registered_shields = []
async def initialize(self) -> None:
try:
@ -45,16 +44,27 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
async def register_shield(self, shield: ShieldDef) -> None:
if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
shield_params = shield.params
if "guardrailIdentifier" not in shield_params:
raise ValueError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in shield_params:
raise ValueError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
@ -69,52 +79,38 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
"""
try:
logger.debug(f"run_shield::{params}::messages={messages}")
if "guardrailIdentifier" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
shield_params = shield_def.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=params.get("guardrailIdentifier"),
guardrailVersion=params.get("guardrailVersion"),
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
logger.debug(f"run_shield:: response: {response}::")
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=shield_params["guardrailIdentifier"],
guardrailVersion=shield_params["guardrailVersion"],
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
except Exception:
error_str = traceback.format_exc()
logger.error(
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
return None

View file

@ -9,14 +9,12 @@ from .config import SampleConfig
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
class SampleSafetyImpl(Safety, RoutableProvider):
class SampleSafetyImpl(Safety):
def __init__(self, config: SampleConfig):
self.config = config
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
async def register_shield(self, shield: ShieldDef) -> None:
# these are the safety shields the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass

View file

@ -12,7 +12,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
from .config import TogetherSafetyConfig
SAFETY_SHIELD_MODEL_MAP = {
TOGETHER_SHIELD_MODEL_MAP = {
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
@ -22,7 +22,6 @@ SAFETY_SHIELD_MODEL_MAP = {
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config
self.register_shields = []
async def initialize(self) -> None:
pass
@ -34,26 +33,15 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
if shield.type != ShieldType.llama_guard.value:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
self.registered_shields.append(shield)
async def list_shields(self) -> List[ShieldDef]:
return self.registered_shields
async def get_shield(self, identifier: str) -> Optional[ShieldDef]:
for shield in self.registered_shields:
if shield.identifier == identifier:
return shield
return None
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
shield_def = await self.get_shield(shield_type)
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
model = shield_def.params.get("model", "llama_guard")
if model not in SAFETY_SHIELD_MODEL_MAP:
if model not in TOGETHER_SHIELD_MODEL_MAP:
raise ValueError(f"Unsupported safety model: {model}")
together_api_key = None
@ -73,7 +61,9 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
if message.role in (Role.user.value, Role.assistant.value):
api_messages.append({"role": message.role, "content": message.content})
violation = await get_safety_response(together_api_key, model, api_messages)
violation = await get_safety_response(
together_api_key, TOGETHER_SHIELD_MODEL_MAP[model], api_messages
)
return RunShieldResponse(violation=violation)