Add back Guardrails section

This commit is contained in:
Jash Gulabrai 2025-04-10 10:57:25 -04:00
parent 57813f5606
commit a671b33589
2 changed files with 1078 additions and 40 deletions

File diff suppressed because it is too large Load diff

View file

@ -104,6 +104,16 @@ class NeMoGuardrails:
self.threshold = threshold self.threshold = threshold
self.guardrails_service_url = config.guardrails_service_url self.guardrails_service_url = config.guardrails_service_url
async def _guardrails_post(self, path: str, data: Any | None):
"""Helper for making POST requests to the guardrails service."""
headers = {
"Accept": "application/json",
}
print(data)
response = requests.post(url=f"{self.guardrails_service_url}{path}", headers=headers, json=data)
response.raise_for_status()
return response.json()
async def run(self, messages: List[Message]) -> RunShieldResponse: async def run(self, messages: List[Message]) -> RunShieldResponse:
""" """
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API. Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
@ -118,9 +128,6 @@ class NeMoGuardrails:
Raises: Raises:
requests.HTTPError: If the POST request fails. requests.HTTPError: If the POST request fails.
""" """
headers = {
"Accept": "application/json",
}
request_data = { request_data = {
"model": self.model, "model": self.model,
"messages": convert_pydantic_to_json_value(messages), "messages": convert_pydantic_to_json_value(messages),
@ -134,15 +141,11 @@ class NeMoGuardrails:
"config_id": self.config_id, "config_id": self.config_id,
}, },
} }
response = requests.post( response = await self._guardrails_post(path="/v1/guardrail/checks", data=request_data)
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
) if response["status"] == "blocked":
response.raise_for_status()
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
response_json = response.json()
if response_json["status"] == "blocked":
user_message = "Sorry I cannot do this." user_message = "Sorry I cannot do this."
metadata = response_json["rails_status"] metadata = response["rails_status"]
return RunShieldResponse( return RunShieldResponse(
violation=SafetyViolation( violation=SafetyViolation(
@ -151,4 +154,5 @@ class NeMoGuardrails:
metadata=metadata, metadata=metadata,
) )
) )
return RunShieldResponse(violation=None) return RunShieldResponse(violation=None)