Introduce model_store, shield_store, memory_bank_store

This commit is contained in:
Ashwin Bharambe 2024-10-06 16:29:33 -07:00 committed by Ashwin Bharambe
parent e45a417543
commit 91e0063593
19 changed files with 172 additions and 297 deletions

View file

@ -7,14 +7,12 @@
import json
import logging
import traceback
from typing import Any, Dict, List
import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RoutableProvider
from .config import BedrockSafetyConfig
@ -22,16 +20,17 @@ from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
SUPPORTED_SHIELD_TYPES = [
"bedrock_guardrail",
BEDROCK_SUPPORTED_SHIELDS = [
ShieldType.generic_content_shield.value,
]
class BedrockSafetyAdapter(Safety, RoutableProvider):
class BedrockSafetyAdapter(Safety):
def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
self.config = config
self.registered_shields = []
async def initialize(self) -> None:
try:
@ -45,16 +44,27 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
for key in routing_keys:
if key not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {key}")
async def register_shield(self, shield: ShieldDef) -> None:
if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
shield_params = shield.params
if "guardrailIdentifier" not in shield_params:
raise ValueError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in shield_params:
raise ValueError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type not in SUPPORTED_SHIELD_TYPES:
raise ValueError(f"Unknown safety shield type: {shield_type}")
shield_def = await self.shield_store.get_shield(shield_type)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
@ -69,52 +79,38 @@ class BedrockSafetyAdapter(Safety, RoutableProvider):
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
"""
try:
logger.debug(f"run_shield::{params}::messages={messages}")
if "guardrailIdentifier" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
shield_params = shield_def.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=params.get("guardrailIdentifier"),
guardrailVersion=params.get("guardrailVersion"),
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
logger.debug(f"run_shield:: response: {response}::")
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=shield_params["guardrailIdentifier"],
guardrailVersion=shield_params["guardrailVersion"],
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
except Exception:
error_str = traceback.format_exc()
logger.error(
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
return None