further changes to nvidia.py

This commit is contained in:
r-bit-rry 2025-11-24 22:16:50 +02:00
parent fedfcf9e44
commit c580bae369

View file

@ -131,8 +131,7 @@ class NeMoGuardrails:
messages (List[Message]): A list of Message objects to be checked for safety violations.
Returns:
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
RunShieldResponse: Response with SafetyViolation if content is blocked, None otherwise.
Raises:
requests.HTTPError: If the POST request fails.
@ -143,16 +142,38 @@ class NeMoGuardrails:
}
response = await self._guardrails_post(path="/v1/chat/completions", data=request_data)
if response["status"] == "blocked":
user_message = "Sorry I cannot do this."
metadata = response["rails_status"]
# Support legacy format with explicit status field
if "status" in response and response["status"] == "blocked":
return RunShieldResponse(
violation=SafetyViolation(
user_message="Sorry I cannot do this.",
violation_level=ViolationLevel.ERROR,
metadata=response.get("rails_status", {}),
)
)
# NOTE: The implementation targets the actual behavior of the NeMo Guardrails server
# as defined in 'nemoguardrails/server/api.py'. The 'RequestBody' class accepts
# 'config_id' at the top level, and 'ResponseBody' returns a 'messages' array,
# distinct from the OpenAI 'choices' format often referenced in documentation.
response_messages = response.get("messages", [])
if response_messages:
content = response_messages[0].get("content", "").strip()
else:
choices = response.get("choices", [])
if choices:
content = choices[0].get("message", {}).get("content", "").strip()
else:
content = ""
refusal_phrases = ["sorry i cannot do this", "i cannot help with that", "i can't assist with that"]
is_blocked = not content or any(phrase in content.lower() for phrase in refusal_phrases)
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
user_message="Sorry I cannot do this.",
violation_level=ViolationLevel.ERROR,
metadata=metadata,
metadata={"reason": "Content violates safety guidelines", "response": content or "(empty)"},
) if is_blocked else None
)
)
return RunShieldResponse(violation=None)