From aea5b2745d91b2208708feb5e85fc31a9d31b0b1 Mon Sep 17 00:00:00 2001 From: Kaushik Date: Thu, 6 Feb 2025 20:01:31 -0800 Subject: [PATCH] adding the fiddlecube provider --- llama_stack/providers/registry/safety.py | 9 ++ .../remote/safety/fiddlecube/__init__.py | 18 +++ .../remote/safety/fiddlecube/config.py | 14 +++ .../remote/safety/fiddlecube/fiddlecube.py | 104 ++++++++++++++++++ 4 files changed, 145 insertions(+) create mode 100644 llama_stack/providers/remote/safety/fiddlecube/__init__.py create mode 100644 llama_stack/providers/remote/safety/fiddlecube/config.py create mode 100644 llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index b9f7b6d78..1e57982e9 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -85,4 +85,13 @@ Provider `inline::meta-reference` for API `safety` does not work with the latest config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", ), ), + remote_provider_spec( + api=Api.safety, + adapter=AdapterSpec( + adapter_type="fiddlecube", + pip_packages=["httpx"], + module="llama_stack.providers.remote.safety.fiddlecube", + config_class="llama_stack.providers.remote.safety.fiddlecube.FiddlecubeSafetyConfig", + ), + ), ] diff --git a/llama_stack/providers/remote/safety/fiddlecube/__init__.py b/llama_stack/providers/remote/safety/fiddlecube/__init__.py new file mode 100644 index 000000000..d4cceb6e7 --- /dev/null +++ b/llama_stack/providers/remote/safety/fiddlecube/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from typing import Any + +from .config import FiddlecubeSafetyConfig + + +async def get_adapter_impl(config: FiddlecubeSafetyConfig, _deps) -> Any: + from .fiddlecube import FiddlecubeSafetyAdapter + + impl = FiddlecubeSafetyAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/safety/fiddlecube/config.py b/llama_stack/providers/remote/safety/fiddlecube/config.py new file mode 100644 index 000000000..f26e79f20 --- /dev/null +++ b/llama_stack/providers/remote/safety/fiddlecube/config.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + +from llama_models.schema_utils import json_schema_type + + +@json_schema_type +class FiddlecubeSafetyConfig(BaseModel): + pass diff --git a/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py b/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py new file mode 100644 index 000000000..84228c5f8 --- /dev/null +++ b/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import logging + +from typing import Any, Dict, List + +from llama_stack.apis.inference import Message + +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield +from llama_stack.providers.datatypes import ShieldsProtocolPrivate + +from .config import FiddlecubeSafetyConfig + + +logger = logging.getLogger(__name__) + + +class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate): + def __init__(self, config: FiddlecubeSafetyConfig) -> None: + self.config = config + self.registered_shields = [] + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_shield(self, shield: Shield) -> None: + pass + + async def run_shield( + self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None + ) -> RunShieldResponse: + # Set up FiddleCube API using httpx + # [TBD] convert the `messages` into format FiddleCube expects + # make a call to the API for guardrails + # convert the [TBD] response into the format RunShieldResponse expects + # return the response + return RunShieldResponse() + + shield = await self.shield_store.get_shield(shield_id) + if not shield: + raise ValueError(f"Shield {shield_id} not found") + + """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format + ```content = [ + { + "text": { + "text": "Is the AB503 Product a better investment than the S&P 500?" + } + } + ]``` + However the incoming messages are of this type UserMessage(content=....) coming from + https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py + + They contain content, role . For now we will extract the content and default the "qualifiers": ["query"] + """ + + shield_params = shield.params + logger.debug(f"run_shield::{shield_params}::messages={messages}") + + # - convert the messages into format Bedrock expects + content_messages = [] + for message in messages: + content_messages.append({"text": {"text": message.content}}) + logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") + + response = self.bedrock_runtime_client.apply_guardrail( + guardrailIdentifier=shield.provider_resource_id, + guardrailVersion=shield_params["guardrailVersion"], + source="OUTPUT", # or 'INPUT' depending on your use case + content=content_messages, + ) + if response["action"] == "GUARDRAIL_INTERVENED": + user_message = "" + metadata = {} + for output in response["outputs"]: + # guardrails returns a list - however for this implementation we will leverage the last values + user_message = output["text"] + for assessment in response["assessments"]: + # guardrails returns a list - however for this implementation we will leverage the last values + metadata = dict(assessment) + + return RunShieldResponse( + violation=SafetyViolation( + user_message=user_message, + violation_level=ViolationLevel.ERROR, + metadata=metadata, + ) + ) + + return RunShieldResponse()