Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
further changes to nvidia.py
This commit is contained in:
parent fedfcf9e44
commit c580bae369

1 changed file with 30 additions and 9 deletions
@@ -131,8 +131,7 @@ class NeMoGuardrails:
             messages (List[Message]): A list of Message objects to be checked for safety violations.
 
         Returns:
-            RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
-                RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
+            RunShieldResponse: Response with SafetyViolation if content is blocked, None otherwise.
 
         Raises:
             requests.HTTPError: If the POST request fails.
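For context, a minimal sketch of the contract described by the new one-line docstring. The dataclasses below are illustrative stand-ins for llama-stack's RunShieldResponse and SafetyViolation types, written only for this sketch, not the real definitions:

# Illustrative stand-ins for llama_stack's types (an assumption, not the real defs).
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class SafetyViolation:
    user_message: str
    violation_level: str
    metadata: dict = field(default_factory=dict)

@dataclass
class RunShieldResponse:
    violation: Optional[SafetyViolation] = None

# Callers branch on `violation is None`, per the updated docstring.
blocked = RunShieldResponse(
    violation=SafetyViolation("Sorry I cannot do this.", "error", {"rails": "blocked"})
)
allowed = RunShieldResponse()

for resp in (blocked, allowed):
    if resp.violation is not None:
        print("blocked:", resp.violation.user_message)
    else:
        print("allowed")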
@@ -143,16 +142,38 @@ class NeMoGuardrails:
         }
         response = await self._guardrails_post(path="/v1/chat/completions", data=request_data)
 
-        if response["status"] == "blocked":
-            user_message = "Sorry I cannot do this."
-            metadata = response["rails_status"]
-
+        # Support legacy format with explicit status field
+        if "status" in response and response["status"] == "blocked":
+            return RunShieldResponse(
+                violation=SafetyViolation(
+                    user_message="Sorry I cannot do this.",
+                    violation_level=ViolationLevel.ERROR,
+                    metadata=response.get("rails_status", {}),
+                )
+            )
+
+        # NOTE: The implementation targets the actual behavior of the NeMo Guardrails server
+        # as defined in 'nemoguardrails/server/api.py'. The 'RequestBody' class accepts
+        # 'config_id' at the top level, and 'ResponseBody' returns a 'messages' array,
+        # distinct from the OpenAI 'choices' format often referenced in documentation.
+        response_messages = response.get("messages", [])
+        if response_messages:
+            content = response_messages[0].get("content", "").strip()
+        else:
+            choices = response.get("choices", [])
+            if choices:
+                content = choices[0].get("message", {}).get("content", "").strip()
+            else:
+                content = ""
+
+        refusal_phrases = ["sorry i cannot do this", "i cannot help with that", "i can't assist with that"]
+        is_blocked = not content or any(phrase in content.lower() for phrase in refusal_phrases)
         return RunShieldResponse(
             violation=SafetyViolation(
-                user_message=user_message,
+                user_message="Sorry I cannot do this.",
                 violation_level=ViolationLevel.ERROR,
-                metadata=metadata,
-            )
+                metadata={"reason": "Content violates safety guidelines", "response": content or "(empty)"},
+            ) if is_blocked else None
         )
 
         return RunShieldResponse(violation=None)
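For reference, here is the new blocking heuristic distilled into a pure function and exercised against the three payload shapes the code handles: the legacy explicit "status" field, the server's "messages" array, and the OpenAI-style "choices" fallback. `extract_content` and `is_blocked_response` are hypothetical helpers written for this sketch, not functions in the adapter:

# Sketch of the new blocking heuristic, distilled from the diff above.
# These helpers are hypothetical; the adapter inlines this logic in run().

REFUSAL_PHRASES = ["sorry i cannot do this", "i cannot help with that", "i can't assist with that"]

def extract_content(response: dict) -> str:
    """Pull the reply text from either response shape the server may return."""
    messages = response.get("messages", [])
    if messages:
        return messages[0].get("content", "").strip()
    choices = response.get("choices", [])
    if choices:
        return choices[0].get("message", {}).get("content", "").strip()
    return ""

def is_blocked_response(response: dict) -> bool:
    # Legacy format: an explicit top-level status field.
    if response.get("status") == "blocked":
        return True
    # Current format: infer blocking from an empty or refusal-style reply.
    content = extract_content(response)
    return not content or any(phrase in content.lower() for phrase in REFUSAL_PHRASES)

# Exercise all three payload shapes.
legacy = {"status": "blocked", "rails_status": {"self check input": "blocked"}}
messages_fmt = {"messages": [{"role": "assistant", "content": "I cannot help with that."}]}
choices_fmt = {"choices": [{"message": {"role": "assistant", "content": "Here is a pasta recipe."}}]}

for name, payload in [("legacy", legacy), ("messages", messages_fmt), ("choices", choices_fmt)]:
    print(name, "->", "blocked" if is_blocked_response(payload) else "allowed")
# Expected: legacy -> blocked, messages -> blocked, choices -> allowed

Note the conservative default in the heuristic: an empty reply is treated as blocked, so a malformed or empty server response fails closed rather than open.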