NVIDIA Safety Provider for LlamaStack

This provider enables safety checks and guardrails for LLM interactions using NVIDIA's NeMo Guardrails service.

Features

  • Run safety checks for messages

Getting Started

Prerequisites

  • LlamaStack with NVIDIA configuration
  • Access to NVIDIA NeMo Guardrails service
  • A deployed NIM for the model used by the safety check

Setup

Build the NVIDIA environment:

llama stack build --distro nvidia --image-type venv

Basic Usage with the LlamaStack Python Client

Initialize the client

import os

os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"

from llama_stack.core.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
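
Before initializing the client, it can help to fail early when the environment is incomplete. The following is a minimal sketch, assuming that both `NVIDIA_API_KEY` and `NVIDIA_GUARDRAILS_URL` need to be set (an assumption based on the snippet above; your deployment may need only one of them). `missing_env_vars` is a hypothetical helper and not part of LlamaStack.

```python
import os

# Assumed list of variables, taken from the snippet above; adjust to your deployment.
REQUIRED_VARS = ("NVIDIA_API_KEY", "NVIDIA_GUARDRAILS_URL")


def missing_env_vars(required=REQUIRED_VARS, env=None):
    """Return the names in `required` that are unset or empty in `env`."""
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]


if __name__ == "__main__":
    missing = missing_env_vars()
    if missing:
        print("Set these variables before initializing the client: " + ", ".join(missing))
    else:
        print("Environment looks complete")
```

Passing `env` explicitly keeps the check easy to test without touching the real process environment.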

Create a safety shield

from llama_stack.apis.safety import Shield
from llama_stack.apis.inference import Message

# Create a safety shield
shield = Shield(
    shield_id="your-shield-id",
    provider_resource_id="safety-model-id",  # The model to use for safety checks
    description="Safety checks for content moderation",
)

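# Register the shield. Note that `await` only works inside an async function
# (or an async-aware shell such as a notebook).
await client.safety.register_shield(shield)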

Run safety checks

# Messages to check
messages = [Message(role="user", content="Your message to check")]

# Run safety check
response = await client.safety.run_shield(
    shield_id="your-shield-id",
    messages=messages,
)

# Check for violations
if response.violation:
    print(f"Safety violation detected: {response.violation.user_message}")
    print(f"Violation level: {response.violation.violation_level}")
    print(f"Metadata: {response.violation.metadata}")
else:
    print("No safety violations detected")
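
If the violation check is needed in several places, it can be factored into a small helper. The sketch below is illustrative only: `describe_violation` is a hypothetical function, and the `SimpleNamespace` objects are stand-ins that mimic just the attributes used above (`violation`, `user_message`, `violation_level`, `metadata`) with made-up values. It does not call the NVIDIA service.

```python
from types import SimpleNamespace


def describe_violation(response):
    """Return a one-line summary of a shield response, or None if no violation."""
    violation = getattr(response, "violation", None)
    if violation is None:
        return None
    return (
        f"[{violation.violation_level}] {violation.user_message} "
        f"(metadata: {violation.metadata})"
    )


# Stand-in responses with made-up values, used only to exercise the helper.
clean = SimpleNamespace(violation=None)
flagged = SimpleNamespace(
    violation=SimpleNamespace(
        violation_level="error",
        user_message="Sorry I cannot do this.",
        metadata={"reason": "blocked"},
    )
)

print(describe_violation(clean))
print(describe_violation(flagged))
```

Returning `None` for a clean response lets callers branch with a plain `if summary:` check, mirroring the `if response.violation:` test above.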