diff --git a/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py b/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py
index 9a468f663..7a4faea2e 100644
--- a/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py
+++ b/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py
@@ -47,7 +47,7 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
         request_body = {
             "messages": [message.model_dump(mode="json") for message in messages],
         }
-        if params.get("excluded_categories"):
+        if params and params.get("excluded_categories"):
             request_body["excluded_categories"] = params.get("excluded_categories")
         headers = {"Content-Type": "application/json"}
         response = await client.post(
@@ -62,22 +62,15 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
 
         response_data = response.json()
 
-        if response_data.get("action") == "GUARDRAIL_INTERVENED":
-            user_message = ""
-            metadata = {}
-
-            outputs = response_data.get("outputs", [])
-            if outputs:
-                user_message = outputs[-1].get("text", "Safety violation detected")
-
-            assessments = response_data.get("assessments", [])
-            for assessment in assessments:
-                metadata.update(dict(assessment))
-
+        if response_data.get("violation"):
+            violation = response_data.get("violation")
+            user_message = violation.get("user_message")
+            metadata = violation.get("metadata")
+            violation_level = ViolationLevel(violation.get("violation_level"))
             return RunShieldResponse(
                 violation=SafetyViolation(
                     user_message=user_message,
-                    violation_level=ViolationLevel.ERROR,
+                    violation_level=violation_level,
                     metadata=metadata,
                 )
             )
diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py
index 98b19616c..133aa4f5a 100644
--- a/llama_stack/providers/tests/safety/conftest.py
+++ b/llama_stack/providers/tests/safety/conftest.py
@@ -55,7 +55,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     ),
     pytest.param(
         {
-            "inference": "bedrock",
+            "inference": "together",
             "safety": "fiddlecube",
         },
         id="fiddlecube",
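
Not part of the patch: a minimal sketch of the JSON shape the rewritten branch expects from the FiddleCube endpoint. The payload values below are hypothetical illustrations; the diff only shows that `violation.user_message`, `violation.metadata`, and `violation.violation_level` are read, and that `violation_level` must be a value accepted by the `ViolationLevel` enum. The import path is assumed to mirror the adapter's existing llama-stack imports.

```python
from llama_stack.apis.safety import RunShieldResponse, SafetyViolation, ViolationLevel

# Hypothetical guardrail response; values are illustrative, not a documented FiddleCube contract.
response_data = {
    "violation": {
        "user_message": "Unsafe content detected",
        "metadata": {"category": "violence"},
        # Must round-trip through the ViolationLevel enum, whatever its concrete values are.
        "violation_level": ViolationLevel.ERROR.value,
    }
}

# Mirrors the parsing path introduced by the patch.
violation = response_data.get("violation")
if violation:
    result = RunShieldResponse(
        violation=SafetyViolation(
            user_message=violation.get("user_message"),
            violation_level=ViolationLevel(violation.get("violation_level")),
            metadata=violation.get("metadata"),
        )
    )
```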