forked from phoenix-oss/llama-stack-mirror
feat: Add unit tests for NVIDIA safety (#1897)
# What does this PR do? This PR adds unit tests for the NVIDIA Safety provider implementation. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] 1. Ran `./scripts/unit-tests.sh tests/unit/providers/nvidia/test_safety.py` from the root of the project. Verified tests pass. ``` tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_init_nemo_guardrails Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_init_nemo_guardrails_invalid_temperature Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_register_shield_with_valid_id Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_register_shield_without_id Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_run_shield_allowed Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_run_shield_blocked Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_run_shield_http_error Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED tests/unit/providers/nvidia/test_safety.py::TestNVIDIASafetyAdapter::test_run_shield_not_found Initializing NVIDIASafetyAdapter(http://nemo.test)... PASSED ``` [//]: # (## Documentation) --------- Co-authored-by: Jash Gulabrai <jgulabrai@nvidia.com>
This commit is contained in:
parent
2a74f0db39
commit
c1cb6aad11
2 changed files with 340 additions and 11 deletions
|
@ -104,6 +104,15 @@ class NeMoGuardrails:
|
|||
self.threshold = threshold
|
||||
self.guardrails_service_url = config.guardrails_service_url
|
||||
|
||||
async def _guardrails_post(self, path: str, data: Any | None):
|
||||
"""Helper for making POST requests to the guardrails service."""
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
}
|
||||
response = requests.post(url=f"{self.guardrails_service_url}{path}", headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
"""
|
||||
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
|
||||
|
@ -118,9 +127,6 @@ class NeMoGuardrails:
|
|||
Raises:
|
||||
requests.HTTPError: If the POST request fails.
|
||||
"""
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
}
|
||||
request_data = {
|
||||
"model": self.model,
|
||||
"messages": convert_pydantic_to_json_value(messages),
|
||||
|
@ -134,15 +140,11 @@ class NeMoGuardrails:
|
|||
"config_id": self.config_id,
|
||||
},
|
||||
}
|
||||
response = requests.post(
|
||||
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
|
||||
)
|
||||
response.raise_for_status()
|
||||
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
|
||||
response_json = response.json()
|
||||
if response_json["status"] == "blocked":
|
||||
response = await self._guardrails_post(path="/v1/guardrail/checks", data=request_data)
|
||||
|
||||
if response["status"] == "blocked":
|
||||
user_message = "Sorry I cannot do this."
|
||||
metadata = response_json["rails_status"]
|
||||
metadata = response["rails_status"]
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
|
@ -151,4 +153,5 @@ class NeMoGuardrails:
|
|||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return RunShieldResponse(violation=None)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue