diff --git a/src/llama_stack/providers/remote/safety/nvidia/nvidia.py b/src/llama_stack/providers/remote/safety/nvidia/nvidia.py index 43ff45cc9..d217647a4 100644 --- a/src/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/src/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -125,43 +125,57 @@ class NeMoGuardrails: async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse: """ - Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API. + Queries the /v1/chat/completions endpoint of the NeMo guardrails deployed API. Args: messages (List[Message]): A list of Message objects to be checked for safety violations. Returns: - RunShieldResponse: If the response indicates a violation ("blocked" status), returns a - RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None. + RunShieldResponse: Response with SafetyViolation if content is blocked, None otherwise. Raises: requests.HTTPError: If the POST request fails. """ request_data = { - "model": self.model, + "config_id": self.config_id, "messages": [{"role": message.role, "content": message.content} for message in messages], - "temperature": self.temperature, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - "max_tokens": 160, - "stream": False, - "guardrails": { - "config_id": self.config_id, - }, } - response = await self._guardrails_post(path="/v1/guardrail/checks", data=request_data) - - if response["status"] == "blocked": - user_message = "Sorry I cannot do this." - metadata = response["rails_status"] + response = await self._guardrails_post(path="/v1/chat/completions", data=request_data) + # Support legacy format with explicit status field + if "status" in response and response["status"] == "blocked": return RunShieldResponse( violation=SafetyViolation( - user_message=user_message, + user_message="Sorry I cannot do this.", violation_level=ViolationLevel.ERROR, - metadata=metadata, + metadata=response.get("rails_status", {}), ) ) - return RunShieldResponse(violation=None) + # NOTE: The implementation targets the actual behavior of the NeMo Guardrails server + # as defined in 'nemoguardrails/server/api.py'. The 'RequestBody' class accepts + # 'config_id' at the top level, and 'ResponseBody' returns a 'messages' array, + # distinct from the OpenAI 'choices' format often referenced in documentation. + + response_messages = response.get("messages", []) + if response_messages: + content = response_messages[0].get("content", "").strip() + else: + choices = response.get("choices", []) + if choices: + content = choices[0].get("message", {}).get("content", "").strip() + else: + content = "" + + refusal_phrases = ["sorry i cannot do this", "i cannot help with that", "i can't assist with that"] + is_blocked = not content or any(phrase in content.lower() for phrase in refusal_phrases) + + return RunShieldResponse( + violation=SafetyViolation( + user_message="Sorry I cannot do this.", + violation_level=ViolationLevel.ERROR, + metadata={"reason": "Content violates safety guidelines", "response": content or "(empty)"}, + ) + if is_blocked + else None + ) diff --git a/tests/unit/providers/nvidia/test_safety.py b/tests/unit/providers/nvidia/test_safety.py index 07e04ddea..b99e4dd88 100644 --- a/tests/unit/providers/nvidia/test_safety.py +++ b/tests/unit/providers/nvidia/test_safety.py @@ -152,22 +152,13 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post): # Verify the Guardrails API was called correctly mock_guardrails_post.assert_called_once_with( - path="/v1/guardrail/checks", + path="/v1/chat/completions", data={ - "model": shield_id, + "config_id": "self-check", "messages": [ {"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, ], - "temperature": 1.0, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - "max_tokens": 160, - "stream": False, - "guardrails": { - "config_id": "self-check", - }, }, ) @@ -206,22 +197,13 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post): # Verify the Guardrails API was called correctly mock_guardrails_post.assert_called_once_with( - path="/v1/guardrail/checks", + path="/v1/chat/completions", data={ - "model": shield_id, + "config_id": "self-check", "messages": [ {"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, ], - "temperature": 1.0, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - "max_tokens": 160, - "stream": False, - "guardrails": { - "config_id": "self-check", - }, }, ) @@ -286,22 +268,13 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post): # Verify the Guardrails API was called correctly mock_guardrails_post.assert_called_once_with( - path="/v1/guardrail/checks", + path="/v1/chat/completions", data={ - "model": shield_id, + "config_id": "self-check", "messages": [ {"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, ], - "temperature": 1.0, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - "max_tokens": 160, - "stream": False, - "guardrails": { - "config_id": "self-check", - }, }, ) # Verify the exception message