mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-21 12:09:40 +00:00
Add back Guardrails section
This commit is contained in:
parent
57813f5606
commit
a671b33589
2 changed files with 1078 additions and 40 deletions
File diff suppressed because it is too large
Load diff
|
@ -104,6 +104,16 @@ class NeMoGuardrails:
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
self.guardrails_service_url = config.guardrails_service_url
|
self.guardrails_service_url = config.guardrails_service_url
|
||||||
|
|
||||||
|
async def _guardrails_post(self, path: str, data: Any | None, *, timeout: float = 60.0):
    """POST a JSON payload to the guardrails service and return the decoded reply.

    Args:
        path: Path appended verbatim to ``self.guardrails_service_url``
            (so it is expected to start with ``/``).
        data: JSON-serializable request body passed to ``requests`` as ``json=``.
        timeout: Seconds to wait for the service before giving up.

    Returns:
        The response body decoded from JSON.

    Raises:
        requests.HTTPError: If the service answers with an error status.
        requests.Timeout: If the service does not answer within ``timeout``.
    """
    import asyncio

    headers = {
        "Accept": "application/json",
    }

    # NOTE: the payload is intentionally not printed or logged here; it carries
    # the conversation messages being checked and should not leak to stdout.

    def _post():
        # Without an explicit timeout ``requests`` waits indefinitely on an
        # unresponsive service.
        response = requests.post(
            url=f"{self.guardrails_service_url}{path}",
            headers=headers,
            json=data,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()

    # ``requests`` is blocking; run it in a worker thread so this coroutine
    # does not stall the event loop while waiting on the network.
    return await asyncio.to_thread(_post)
|
||||||
|
|
||||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||||
"""
|
"""
|
||||||
Queries the /v1/guardrail/checks endpoint of the NeMo guardrails deployed API.
|
Queries the /v1/guardrail/checks endpoint of the NeMo guardrails deployed API.
|
||||||
|
@ -118,9 +128,6 @@ class NeMoGuardrails:
|
||||||
Raises:
|
Raises:
|
||||||
requests.HTTPError: If the POST request fails.
|
requests.HTTPError: If the POST request fails.
|
||||||
"""
|
"""
|
||||||
headers = {
|
|
||||||
"Accept": "application/json",
|
|
||||||
}
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": convert_pydantic_to_json_value(messages),
|
"messages": convert_pydantic_to_json_value(messages),
|
||||||
|
@ -134,15 +141,11 @@ class NeMoGuardrails:
|
||||||
"config_id": self.config_id,
|
"config_id": self.config_id,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
response = requests.post(
|
response = await self._guardrails_post(path="/v1/guardrail/checks", data=request_data)
|
||||||
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
|
|
||||||
)
|
if response["status"] == "blocked":
|
||||||
response.raise_for_status()
|
|
||||||
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
|
|
||||||
response_json = response.json()
|
|
||||||
if response_json["status"] == "blocked":
|
|
||||||
user_message = "Sorry I cannot do this."
|
user_message = "Sorry I cannot do this."
|
||||||
metadata = response_json["rails_status"]
|
metadata = response["rails_status"]
|
||||||
|
|
||||||
return RunShieldResponse(
|
return RunShieldResponse(
|
||||||
violation=SafetyViolation(
|
violation=SafetyViolation(
|
||||||
|
@ -151,4 +154,5 @@ class NeMoGuardrails:
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return RunShieldResponse(violation=None)
|
return RunShieldResponse(violation=None)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue