From c580bae3694d7e53a5c0ecab0ee2dec15eaddbe2 Mon Sep 17 00:00:00 2001 From: r-bit-rry Date: Mon, 24 Nov 2025 22:16:50 +0200 Subject: [PATCH] further changes to nvidia.py --- .../providers/remote/safety/nvidia/nvidia.py | 39 ++++++++++++++----- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/src/llama_stack/providers/remote/safety/nvidia/nvidia.py b/src/llama_stack/providers/remote/safety/nvidia/nvidia.py index bb6fcee1b..9e24dd109 100644 --- a/src/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/src/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -131,8 +131,7 @@ class NeMoGuardrails: messages (List[Message]): A list of Message objects to be checked for safety violations. Returns: - RunShieldResponse: If the response indicates a violation ("blocked" status), returns a - RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None. + RunShieldResponse: Response with SafetyViolation if content is blocked, None otherwise. Raises: requests.HTTPError: If the POST request fails. @@ -143,16 +142,38 @@ class NeMoGuardrails: } response = await self._guardrails_post(path="/v1/chat/completions", data=request_data) - if response["status"] == "blocked": - user_message = "Sorry I cannot do this." - metadata = response["rails_status"] - + # Support legacy format with explicit status field + if "status" in response and response["status"] == "blocked": return RunShieldResponse( violation=SafetyViolation( - user_message=user_message, + user_message="Sorry I cannot do this.", violation_level=ViolationLevel.ERROR, - metadata=metadata, + metadata=response.get("rails_status", {}), ) ) - return RunShieldResponse(violation=None) + # NOTE: The implementation targets the actual behavior of the NeMo Guardrails server + # as defined in 'nemoguardrails/server/api.py'. The 'RequestBody' class accepts + # 'config_id' at the top level, and 'ResponseBody' returns a 'messages' array, + # distinct from the OpenAI 'choices' format often referenced in documentation. + + response_messages = response.get("messages", []) + if response_messages: + content = response_messages[0].get("content", "").strip() + else: + choices = response.get("choices", []) + if choices: + content = choices[0].get("message", {}).get("content", "").strip() + else: + content = "" + + refusal_phrases = ["sorry i cannot do this", "i cannot help with that", "i can't assist with that"] + is_blocked = not content or any(phrase in content.lower() for phrase in refusal_phrases) + + return RunShieldResponse( + violation=SafetyViolation( + user_message="Sorry I cannot do this.", + violation_level=ViolationLevel.ERROR, + metadata={"reason": "Content violates safety guidelines", "response": content or "(empty)"}, + ) if is_blocked else None + )