forked from phoenix-oss/llama-stack-mirror
Add a RoutableProvider protocol, support for multiple routing keys (#163)
* Update configure.py to use multiple routing keys for safety * Refactor distribution/datatypes into a providers/datatypes * Cleanup
This commit is contained in:
parent
73decb3781
commit
eb2d8a31a5
24 changed files with 600 additions and 577 deletions
|
@ -4,47 +4,58 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import traceback
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
import boto3
|
||||
|
||||
from llama_stack.apis.safety import * # noqa
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
import json
|
||||
import logging
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
|
||||
import boto3
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety):
|
||||
SUPPORTED_SHIELD_TYPES = [
|
||||
"bedrock_guardrail",
|
||||
]
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety, RoutableProvider):
|
||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||
if not config.aws_profile:
|
||||
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self.config.aws_profile:
|
||||
raise RuntimeError(
|
||||
f"Missing boto_client aws_profile in model info::{self.config}"
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"initializing with profile --- > {self.config}::")
|
||||
self.boto_client_profile = self.config.aws_profile
|
||||
print(f"initializing with profile --- > {self.config}")
|
||||
self.boto_client = boto3.Session(
|
||||
profile_name=self.boto_client_profile
|
||||
profile_name=self.config.aws_profile
|
||||
).client("bedrock-runtime")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e
|
||||
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
for key in routing_keys:
|
||||
if key not in SUPPORTED_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
if shield_type not in SUPPORTED_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {shield_type}")
|
||||
|
||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
{
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_models.sku_list import resolve_model
|
||||
from together import Together
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
@ -13,53 +12,52 @@ from llama_stack.apis.safety import (
|
|||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import RoutableProvider
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
|
||||
from .config import TogetherSafetyConfig
|
||||
|
||||
|
||||
SAFETY_SHIELD_TYPES = {
|
||||
"llama_guard": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
}
|
||||
|
||||
|
||||
def shield_type_to_model_name(shield_type: str) -> str:
|
||||
if shield_type == "llama_guard":
|
||||
shield_type = "Llama-Guard-3-8B"
|
||||
|
||||
model = resolve_model(shield_type)
|
||||
if (
|
||||
model is None
|
||||
or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES
|
||||
or model.model_family is not ModelFamily.safety
|
||||
):
|
||||
raise ValueError(
|
||||
f"{shield_type} is not supported, please use of {','.join(SAFETY_SHIELD_TYPES.keys())}"
|
||||
)
|
||||
|
||||
return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
|
||||
|
||||
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
|
||||
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||
for key in routing_keys:
|
||||
if key not in SAFETY_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {key}")
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
if shield_type not in SAFETY_SHIELD_TYPES:
|
||||
raise ValueError(f"Unknown safety shield type: {shield_type}")
|
||||
|
||||
together_api_key = None
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
|
||||
model_name = shield_type_to_model_name(shield_type)
|
||||
model_name = SAFETY_SHIELD_TYPES[shield_type]
|
||||
|
||||
# messages can have role assistant or user
|
||||
api_messages = []
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue