mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 02:58:21 +00:00
add safety violation code
Also improved error handling: response fields are now read with `.get()` to avoid missing-field errors. A safety violation sent from the server is now processed and returned in the response.
This commit is contained in:
parent
49f7e04f83
commit
319df22390
1 changed file with 23 additions and 36 deletions
|
@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
|
|||
RunShieldResponse,
|
||||
Safety,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import SafetyViolation, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
||||
|
@ -41,41 +42,6 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
async def run_shield(
|
||||
self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
# Set up FiddleCube API using httpx
|
||||
# [TBD] convert the `messages` into format FiddleCube expects
|
||||
# make a call to the API for guardrails
|
||||
# convert the [TBD] response into the format RunShieldResponse expects
|
||||
# return the response
|
||||
return RunShieldResponse()
|
||||
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
{
|
||||
"text": {
|
||||
"text": "Is the AB503 Product a better investment than the S&P 500?"
|
||||
}
|
||||
}
|
||||
]```
|
||||
However the incoming messages are of this type UserMessage(content=....) coming from
|
||||
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
|
||||
|
||||
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
|
||||
"""
|
||||
|
||||
shield_params = shield.params
|
||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
|
||||
# Make a call to the FiddleCube API for guardrails
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
request_body = {
|
||||
"messages": [message.model_dump(mode="json") for message in messages],
|
||||
|
@ -98,6 +64,27 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
|
||||
# Convert the response into the format RunShieldResponse expects
|
||||
response_data = response.json()
|
||||
logger.debug("Response data", response_data)
|
||||
logger.debug("Response data: %s", json.dumps(response_data, indent=2))
|
||||
|
||||
# Check if there's a violation based on the response structure
|
||||
if response_data.get("action") == "GUARDRAIL_INTERVENED":
|
||||
user_message = ""
|
||||
metadata = {}
|
||||
|
||||
outputs = response_data.get("outputs", [])
|
||||
if outputs:
|
||||
user_message = outputs[-1].get("text", "Safety violation detected")
|
||||
|
||||
assessments = response_data.get("assessments", [])
|
||||
for assessment in assessments:
|
||||
metadata.update(dict(assessment))
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return RunShieldResponse()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue