add safety violation code

Also improved error handling:
used `.get()` to read response fields, so a missing field no longer raises an error.

Process the safety violation sent from the server and return a
response.
This commit is contained in:
Kaushik 2025-02-10 15:51:51 -08:00
parent 49f7e04f83
commit 319df22390

View file

@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
)
from llama_stack.apis.safety.safety import SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@ -41,41 +42,6 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
async def run_shield(
self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
# Set up FiddleCube API using httpx
# [TBD] convert the `messages` into format FiddleCube expects
# make a call to the API for guardrails
# convert the [TBD] response into the format RunShieldResponse expects
# return the response
return RunShieldResponse()
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
{
"text": {
"text": "Is the AB503 Product a better investment than the S&P 500?"
}
}
]```
However the incoming messages are of this type UserMessage(content=....) coming from
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
"""
shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
# Make a call to the FiddleCube API for guardrails
async with httpx.AsyncClient(timeout=30.0) as client:
request_body = {
"messages": [message.model_dump(mode="json") for message in messages],
@ -98,6 +64,27 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
# Convert the response into the format RunShieldResponse expects
response_data = response.json()
logger.debug("Response data", response_data)
logger.debug("Response data: %s", json.dumps(response_data, indent=2))
# Check if there's a violation based on the response structure
if response_data.get("action") == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
outputs = response_data.get("outputs", [])
if outputs:
user_message = outputs[-1].get("text", "Safety violation detected")
assessments = response_data.get("assessments", [])
for assessment in assessments:
metadata.update(dict(assessment))
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse()