From 319df223904d2e88a0a0b605322cff8205f568e7 Mon Sep 17 00:00:00 2001 From: Kaushik Date: Mon, 10 Feb 2025 15:51:51 -0800 Subject: [PATCH] add safety violation code also improved error handling. used to get fields to prevent field errors. process safety violation sent from the server and return a response --- .../remote/safety/fiddlecube/fiddlecube.py | 59 ++++++++----------- 1 file changed, 23 insertions(+), 36 deletions(-) diff --git a/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py b/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py index dd5c49da5..9c78a6c99 100644 --- a/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py +++ b/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py @@ -15,6 +15,7 @@ from llama_stack.apis.safety import ( RunShieldResponse, Safety, ) +from llama_stack.apis.safety.safety import SafetyViolation, ViolationLevel from llama_stack.apis.shields import Shield from llama_stack.providers.datatypes import ShieldsProtocolPrivate @@ -41,41 +42,6 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate): async def run_shield( self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: - # Set up FiddleCube API using httpx - # [TBD] convert the `messages` into format FiddleCube expects - # make a call to the API for guardrails - # convert the [TBD] response into the format RunShieldResponse expects - # return the response - return RunShieldResponse() - - shield = await self.shield_store.get_shield(shield_id) - if not shield: - raise ValueError(f"Shield {shield_id} not found") - - """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format - ```content = [ - { - "text": { - "text": "Is the AB503 Product a better investment than the S&P 500?" - } - } - ]``` - However the incoming messages are of this type UserMessage(content=....) coming from - https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py - - They contain content, role . For now we will extract the content and default the "qualifiers": ["query"] - """ - - shield_params = shield.params - logger.debug(f"run_shield::{shield_params}::messages={messages}") - - # - convert the messages into format Bedrock expects - content_messages = [] - for message in messages: - content_messages.append({"text": {"text": message.content}}) - logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") - - # Make a call to the FiddleCube API for guardrails async with httpx.AsyncClient(timeout=30.0) as client: request_body = { "messages": [message.model_dump(mode="json") for message in messages], @@ -98,6 +64,27 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate): # Convert the response into the format RunShieldResponse expects response_data = response.json() - logger.debug("Response data", response_data) + logger.debug("Response data: %s", json.dumps(response_data, indent=2)) + + # Check if there's a violation based on the response structure + if response_data.get("action") == "GUARDRAIL_INTERVENED": + user_message = "" + metadata = {} + + outputs = response_data.get("outputs", []) + if outputs: + user_message = outputs[-1].get("text", "Safety violation detected") + + assessments = response_data.get("assessments", []) + for assessment in assessments: + metadata.update(dict(assessment)) + + return RunShieldResponse( + violation=SafetyViolation( + user_message=user_message, + violation_level=ViolationLevel.ERROR, + metadata=metadata, + ) + ) return RunShieldResponse()