adding the fiddlecube provider

This commit is contained in:
Kaushik 2025-02-06 20:01:31 -08:00
parent 7ec79c0297
commit aea5b2745d
4 changed files with 145 additions and 0 deletions

View file

@ -85,4 +85,13 @@ Provider `inline::meta-reference` for API `safety` does not work with the latest
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
),
),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="fiddlecube",
pip_packages=["httpx"],
module="llama_stack.providers.remote.safety.fiddlecube",
config_class="llama_stack.providers.remote.safety.fiddlecube.FiddlecubeSafetyConfig",
),
),
]

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import FiddlecubeSafetyConfig
async def get_adapter_impl(config: FiddlecubeSafetyConfig, _deps) -> Any:
from .fiddlecube import FiddlecubeSafetyAdapter
impl = FiddlecubeSafetyAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_models.schema_utils import json_schema_type
@json_schema_type
class FiddlecubeSafetyConfig(BaseModel):
pass

View file

@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any, Dict, List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import FiddlecubeSafetyConfig
logger = logging.getLogger(__name__)
class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: FiddlecubeSafetyConfig) -> None:
self.config = config
self.registered_shields = []
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
pass
async def run_shield(
self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
# Set up FiddleCube API using httpx
# [TBD] convert the `messages` into format FiddleCube expects
# make a call to the API for guardrails
# convert the [TBD] response into the format RunShieldResponse expects
# return the response
return RunShieldResponse()
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
{
"text": {
"text": "Is the AB503 Product a better investment than the S&P 500?"
}
}
]```
However the incoming messages are of this type UserMessage(content=....) coming from
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
"""
shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = self.bedrock_runtime_client.apply_guardrail(
guardrailIdentifier=shield.provider_resource_id,
guardrailVersion=shield_params["guardrailVersion"],
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse()