adding test fixtures for fiddlecube safety provider

This commit is contained in:
Kaushik 2025-02-17 10:55:51 +05:30
parent 5492a13c79
commit 2da0cd79c7
2 changed files with 8 additions and 15 deletions

View file

@ -47,7 +47,7 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
request_body = {
"messages": [message.model_dump(mode="json") for message in messages],
}
if params.get("excluded_categories"):
if params and params.get("excluded_categories"):
request_body["excluded_categories"] = params.get("excluded_categories")
headers = {"Content-Type": "application/json"}
response = await client.post(
@ -62,22 +62,15 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
response_data = response.json()
if response_data.get("action") == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
outputs = response_data.get("outputs", [])
if outputs:
user_message = outputs[-1].get("text", "Safety violation detected")
assessments = response_data.get("assessments", [])
for assessment in assessments:
metadata.update(dict(assessment))
if response_data.get("violation"):
violation = response_data.get("violation")
user_message = violation.get("user_message")
metadata = violation.get("metadata")
violation_level = ViolationLevel(violation.get("violation_level"))
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
violation_level=violation_level,
metadata=metadata,
)
)

View file

@ -55,7 +55,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
),
pytest.param(
{
"inference": "bedrock",
"inference": "together",
"safety": "fiddlecube",
},
id="fiddlecube",