Add back Guardrails section

This commit is contained in:
Jash Gulabrai 2025-04-10 10:57:25 -04:00
parent 57813f5606
commit a671b33589
2 changed files with 1078 additions and 40 deletions

File diff suppressed because it is too large Load diff

View file

@ -104,6 +104,16 @@ class NeMoGuardrails:
self.threshold = threshold
self.guardrails_service_url = config.guardrails_service_url
async def _guardrails_post(self, path: str, data: Any | None):
"""Helper for making POST requests to the guardrails service."""
headers = {
"Accept": "application/json",
}
print(data)
response = requests.post(url=f"{self.guardrails_service_url}{path}", headers=headers, json=data)
response.raise_for_status()
return response.json()
async def run(self, messages: List[Message]) -> RunShieldResponse:
"""
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
@ -118,9 +128,6 @@ class NeMoGuardrails:
Raises:
requests.HTTPError: If the POST request fails.
"""
headers = {
"Accept": "application/json",
}
request_data = {
"model": self.model,
"messages": convert_pydantic_to_json_value(messages),
@ -134,15 +141,11 @@ class NeMoGuardrails:
"config_id": self.config_id,
},
}
response = requests.post(
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
)
response.raise_for_status()
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
response_json = response.json()
if response_json["status"] == "blocked":
response = await self._guardrails_post(path="/v1/guardrail/checks", data=request_data)
if response["status"] == "blocked":
user_message = "Sorry I cannot do this."
metadata = response_json["rails_status"]
metadata = response["rails_status"]
return RunShieldResponse(
violation=SafetyViolation(
@ -151,4 +154,5 @@ class NeMoGuardrails:
metadata=metadata,
)
)
return RunShieldResponse(violation=None)