forked from phoenix-oss/llama-stack-mirror
feat: Add unit tests for NVIDIA safety (#1897)
# What does this PR do? This PR adds unit tests for the NVIDIA Safety provider implementation. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] 1. Ran `./scripts/unit-tests.sh tests/unit/providers/nvidia/test_safety.py` from the root of the project. Verified tests pass. ``` tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_init_nemo_guardrails Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_init_nemo_guardrails_invalid_temperature Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_register_shield_with_valid_id Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_register_shield_without_id Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_run_shield_allowed Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_run_shield_blocked Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_run_shield_http_error Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_run_shield_not_found Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED ``` [//]: # (## Documentation) --------- Co-authored-by: Jash Gulabrai <jgulabrai@nvidia.com>
This commit is contained in:
parent
2a74f0db39
commit
c1cb6aad11
2 changed files with 340 additions and 11 deletions
|
@ -104,6 +104,15 @@ class NeMoGuardrails:
|
|||
self.threshold = threshold
|
||||
self.guardrails_service_url = config.guardrails_service_url
|
||||
|
||||
async def _guardrails_post(self, path: str, data: Any | None):
|
||||
"""Helper for making POST requests to the guardrails service."""
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
}
|
||||
response = requests.post(url=f"{self.guardrails_service_url}{path}", headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
"""
|
||||
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
|
||||
|
@ -118,9 +127,6 @@ class NeMoGuardrails:
|
|||
Raises:
|
||||
requests.HTTPError: If the POST request fails.
|
||||
"""
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
}
|
||||
request_data = {
|
||||
"model": self.model,
|
||||
"messages": convert_pydantic_to_json_value(messages),
|
||||
|
@ -134,15 +140,11 @@ class NeMoGuardrails:
|
|||
"config_id": self.config_id,
|
||||
},
|
||||
}
|
||||
response = requests.post(
|
||||
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
|
||||
)
|
||||
response.raise_for_status()
|
||||
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
|
||||
response_json = response.json()
|
||||
if response_json["status"] == "blocked":
|
||||
response = await self._guardrails_post(path="/v1/guardrail/checks", data=request_data)
|
||||
|
||||
if response["status"] == "blocked":
|
||||
user_message = "Sorry I cannot do this."
|
||||
metadata = response_json["rails_status"]
|
||||
metadata = response["rails_status"]
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
|
@ -151,4 +153,5 @@ class NeMoGuardrails:
|
|||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return RunShieldResponse(violation=None)
|
||||
|
|
326
tests/unit/providers/nvidia/test_safety.py
Normal file
326
tests/unit/providers/nvidia/test_safety.py
Normal file
|
@ -0,0 +1,326 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference.inference import CompletionMessage, UserMessage
|
||||
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
|
||||
|
||||
|
||||
class TestNVIDIASafetyAdapter(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
|
||||
# Initialize the adapter
|
||||
self.config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
)
|
||||
self.adapter = NVIDIASafetyAdapter(config=self.config)
|
||||
self.shield_store = AsyncMock()
|
||||
self.adapter.shield_store = self.shield_store
|
||||
|
||||
# Mock the HTTP request methods
|
||||
self.guardrails_post_patcher = patch(
|
||||
"llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
|
||||
)
|
||||
self.mock_guardrails_post = self.guardrails_post_patcher.start()
|
||||
self.mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after each test."""
|
||||
self.guardrails_post_patcher.stop()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
|
||||
def _assert_request(
|
||||
self,
|
||||
mock_call: MagicMock,
|
||||
expected_url: str,
|
||||
expected_headers: dict[str, str] | None = None,
|
||||
expected_json: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to verify request details in mock API calls.
|
||||
|
||||
Args:
|
||||
mock_call: The MagicMock object that was called
|
||||
expected_url: The expected URL to which the request was made
|
||||
expected_headers: Optional dictionary of expected request headers
|
||||
expected_json: Optional dictionary of expected JSON payload
|
||||
"""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
# Check headers if provided
|
||||
if expected_headers:
|
||||
for key, value in expected_headers.items():
|
||||
assert call_args[1]["headers"][key] == value
|
||||
|
||||
# Check JSON if provided
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert call_args[1]["json"][key][nested_key] == nested_value
|
||||
else:
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
def test_register_shield_with_valid_id(self):
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
|
||||
# Register the shield
|
||||
self.run_async(self.adapter.register_shield(shield))
|
||||
|
||||
def test_register_shield_without_id(self):
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="",
|
||||
)
|
||||
|
||||
# Register the shield should raise a ValueError
|
||||
with self.assertRaises(ValueError):
|
||||
self.run_async(self.adapter.register_shield(shield))
|
||||
|
||||
def test_run_shield_allowed(self):
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API response
|
||||
self.mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
json.loads(messages[0].model_dump_json()),
|
||||
json.loads(messages[1].model_dump_json()),
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation is None
|
||||
|
||||
def test_run_shield_blocked(self):
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API response
|
||||
self.mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
json.loads(messages[0].model_dump_json()),
|
||||
json.loads(messages[1].model_dump_json()),
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result.violation is not None
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation.user_message == "Sorry I cannot do this."
|
||||
assert result.violation.violation_level == ViolationLevel.ERROR
|
||||
assert result.violation.metadata == {"reason": "harmful_content"}
|
||||
|
||||
def test_run_shield_not_found(self):
|
||||
# Set up shield store to return None
|
||||
shield_id = "non-existent-shield"
|
||||
self.shield_store.get_shield.return_value = None
|
||||
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
]
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was not called
|
||||
self.mock_guardrails_post.assert_not_called()
|
||||
|
||||
def test_run_shield_http_error(self):
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API to raise an exception
|
||||
error_msg = "API Error: 500 Internal Server Error"
|
||||
self.mock_guardrails_post.side_effect = Exception(error_msg)
|
||||
|
||||
# Running the shield should raise an exception
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
with self.assertRaises(Exception) as context:
|
||||
self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
json.loads(messages[0].model_dump_json()),
|
||||
json.loads(messages[1].model_dump_json()),
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
},
|
||||
)
|
||||
# Verify the exception message
|
||||
assert error_msg in str(context.exception)
|
||||
|
||||
def test_init_nemo_guardrails(self):
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
|
||||
test_config_id = "test-custom-config-id"
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id=test_config_id,
|
||||
)
|
||||
# Initialize with default parameters
|
||||
test_model = "test-model"
|
||||
guardrails = NeMoGuardrails(config, test_model)
|
||||
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.9 # Default value
|
||||
assert guardrails.temperature == 1.0 # Default value
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
|
||||
# Initialize with custom parameters
|
||||
guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)
|
||||
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.8
|
||||
assert guardrails.temperature == 0.7
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
|
||||
def test_init_nemo_guardrails_invalid_temperature(self):
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id="test-custom-config-id",
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
NeMoGuardrails(config, "test-model", temperature=0)
|
Loading…
Add table
Add a link
Reference in a new issue