diff --git a/src/llama_stack/providers/remote/safety/nvidia/nvidia.py b/src/llama_stack/providers/remote/safety/nvidia/nvidia.py
index 43ff45cc9..bb6fcee1b 100644
--- a/src/llama_stack/providers/remote/safety/nvidia/nvidia.py
+++ b/src/llama_stack/providers/remote/safety/nvidia/nvidia.py
@@ -125,7 +125,7 @@ class NeMoGuardrails:
 
     async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
         """
-        Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
+        Queries the /v1/chat/completions endpoint of the NeMo guardrails deployed API.
 
         Args:
             messages (List[Message]): A list of Message objects to be checked for safety violations.
@@ -138,19 +138,10 @@ class NeMoGuardrails:
             requests.HTTPError: If the POST request fails.
         """
         request_data = {
-            "model": self.model,
+            "config_id": self.config_id,
             "messages": [{"role": message.role, "content": message.content} for message in messages],
-            "temperature": self.temperature,
-            "top_p": 1,
-            "frequency_penalty": 0,
-            "presence_penalty": 0,
-            "max_tokens": 160,
-            "stream": False,
-            "guardrails": {
-                "config_id": self.config_id,
-            },
         }
-        response = await self._guardrails_post(path="/v1/guardrail/checks", data=request_data)
+        response = await self._guardrails_post(path="/v1/chat/completions", data=request_data)
 
         if response["status"] == "blocked":
             user_message = "Sorry I cannot do this."
diff --git a/tests/unit/providers/nvidia/test_safety.py b/tests/unit/providers/nvidia/test_safety.py
index 07e04ddea..b99e4dd88 100644
--- a/tests/unit/providers/nvidia/test_safety.py
+++ b/tests/unit/providers/nvidia/test_safety.py
@@ -152,22 +152,13 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
 
     # Verify the Guardrails API was called correctly
     mock_guardrails_post.assert_called_once_with(
-        path="/v1/guardrail/checks",
+        path="/v1/chat/completions",
         data={
-            "model": shield_id,
+            "config_id": "self-check",
             "messages": [
                 {"role": "user", "content": "Hello, how are you?"},
                 {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
             ],
-            "temperature": 1.0,
-            "top_p": 1,
-            "frequency_penalty": 0,
-            "presence_penalty": 0,
-            "max_tokens": 160,
-            "stream": False,
-            "guardrails": {
-                "config_id": "self-check",
-            },
         },
     )
 
@@ -206,22 +197,13 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
 
     # Verify the Guardrails API was called correctly
    mock_guardrails_post.assert_called_once_with(
-        path="/v1/guardrail/checks",
+        path="/v1/chat/completions",
         data={
-            "model": shield_id,
+            "config_id": "self-check",
             "messages": [
                 {"role": "user", "content": "Hello, how are you?"},
                 {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
             ],
-            "temperature": 1.0,
-            "top_p": 1,
-            "frequency_penalty": 0,
-            "presence_penalty": 0,
-            "max_tokens": 160,
-            "stream": False,
-            "guardrails": {
-                "config_id": "self-check",
-            },
         },
     )
 
@@ -286,22 +268,13 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
 
     # Verify the Guardrails API was called correctly
     mock_guardrails_post.assert_called_once_with(
-        path="/v1/guardrail/checks",
+        path="/v1/chat/completions",
         data={
-            "model": shield_id,
+            "config_id": "self-check",
             "messages": [
                 {"role": "user", "content": "Hello, how are you?"},
                 {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
             ],
-            "temperature": 1.0,
-            "top_p": 1,
-            "frequency_penalty": 0,
-            "presence_penalty": 0,
-            "max_tokens": 160,
-            "stream": False,
-            "guardrails": {
-                "config_id": "self-check",
-            },
         },
     )
     # Verify the exception message
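
For context, here is a minimal standalone sketch of the request shape this patch adopts. It is an illustration only, not the adapter's actual code path (which goes through NeMoGuardrails._guardrails_post): the base URL, the GUARDRAILS_SERVICE_URL environment variable, and the check_conversation helper are assumptions for the example, while the "self-check" config_id, the flat config_id/messages payload, and the top-level "status" field are taken from the diff above.

import os

import requests

# Hypothetical base URL for a NeMo Guardrails deployment; adjust for your setup.
GUARDRAILS_SERVICE_URL = os.environ.get("GUARDRAILS_SERVICE_URL", "http://localhost:7331")


def check_conversation(messages: list[dict[str, str]], config_id: str = "self-check") -> bool:
    """Return True if the guardrails config blocks the conversation, False otherwise."""
    response = requests.post(
        f"{GUARDRAILS_SERVICE_URL}/v1/chat/completions",
        # The payload now carries config_id at the top level instead of the old
        # nested "guardrails" object plus hardcoded sampling parameters.
        json={"config_id": config_id, "messages": messages},
        timeout=30,
    )
    response.raise_for_status()  # surfaces requests.HTTPError, as the docstring notes
    return response.json().get("status") == "blocked"


if __name__ == "__main__":
    blocked = check_conversation(
        [
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
        ]
    )
    print("blocked" if blocked else "allowed")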