mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
adding test fixtures for fiddlecube safety provider
This commit is contained in:
parent
5492a13c79
commit
2da0cd79c7
2 changed files with 8 additions and 15 deletions
|
@ -47,7 +47,7 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
request_body = {
|
request_body = {
|
||||||
"messages": [message.model_dump(mode="json") for message in messages],
|
"messages": [message.model_dump(mode="json") for message in messages],
|
||||||
}
|
}
|
||||||
if params.get("excluded_categories"):
|
if params and params.get("excluded_categories"):
|
||||||
request_body["excluded_categories"] = params.get("excluded_categories")
|
request_body["excluded_categories"] = params.get("excluded_categories")
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
|
@ -62,22 +62,15 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
|
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
|
|
||||||
if response_data.get("action") == "GUARDRAIL_INTERVENED":
|
if response_data.get("violation"):
|
||||||
user_message = ""
|
violation = response_data.get("violation")
|
||||||
metadata = {}
|
user_message = violation.get("user_message")
|
||||||
|
metadata = violation.get("metadata")
|
||||||
outputs = response_data.get("outputs", [])
|
violation_level = ViolationLevel(violation.get("violation_level"))
|
||||||
if outputs:
|
|
||||||
user_message = outputs[-1].get("text", "Safety violation detected")
|
|
||||||
|
|
||||||
assessments = response_data.get("assessments", [])
|
|
||||||
for assessment in assessments:
|
|
||||||
metadata.update(dict(assessment))
|
|
||||||
|
|
||||||
return RunShieldResponse(
|
return RunShieldResponse(
|
||||||
violation=SafetyViolation(
|
violation=SafetyViolation(
|
||||||
user_message=user_message,
|
user_message=user_message,
|
||||||
violation_level=ViolationLevel.ERROR,
|
violation_level=violation_level,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -55,7 +55,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
),
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
{
|
{
|
||||||
"inference": "bedrock",
|
"inference": "together",
|
||||||
"safety": "fiddlecube",
|
"safety": "fiddlecube",
|
||||||
},
|
},
|
||||||
id="fiddlecube",
|
id="fiddlecube",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue