diff --git a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md
new file mode 100644
index 000000000..a353c67f5
--- /dev/null
+++ b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md
@@ -0,0 +1,85 @@
+# NVIDIA Inference Provider for LlamaStack
+
+This provider enables running inference using NVIDIA NIM.
+
+## Features
+
+- Endpoints for completions, chat completions, and embeddings for registered models
+
+## Getting Started
+
+### Prerequisites
+
+- LlamaStack with the NVIDIA configuration
+- Access to an NVIDIA NIM deployment
+- A NIM deployed for the model you want to use for inference
+
+### Setup
+
+Build the NVIDIA environment:
+
+```bash
+llama stack build --template nvidia --image-type conda
+```
+
+### Basic Usage with the LlamaStack Python Client
+
+#### Initialize the client
+
+```python
+import os
+
+os.environ["NVIDIA_API_KEY"] = (
+    ""  # Required when using a hosted NIM endpoint; not needed for self-hosted NIMs.
+)
+os.environ["NVIDIA_BASE_URL"] = "http://nim.test"  # Your NIM endpoint URL
+
+from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+
+client = LlamaStackAsLibraryClient("nvidia")
+client.initialize()
+```
+
+#### Create Completion
+
+```python
+response = client.inference.completion(
+    model_id="meta-llama/Llama-3.1-8B-Instruct",
+    content="Complete the sentence using one word: Roses are red, violets are :",
+    stream=False,
+    sampling_params={
+        "max_tokens": 50,
+    },
+)
+print(f"Response: {response.content}")
+```
+
+#### Create Chat Completion
+
+```python
+response = client.inference.chat_completion(
+    model_id="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "system",
+            "content": "You must respond to each message with only one word",
+        },
+        {
+            "role": "user",
+            "content": "Complete the sentence using one word: Roses are red, violets are:",
+        },
+    ],
+    stream=False,
+    sampling_params={
+        "max_tokens": 50,
+    },
+)
+print(f"Response: {response.completion_message.content}")
+```
+
+#### Create Embeddings
+
+```python
+# Embeddings require an embedding NIM; the model below is one example embedding model.
+response = client.inference.embeddings(
+    model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
+    contents=["foo", "bar", "baz"],
+)
+print(f"Embeddings: {response.embeddings}")
+```
diff --git a/llama_stack/providers/remote/safety/nvidia/README.md b/llama_stack/providers/remote/safety/nvidia/README.md
new file mode 100644
index 000000000..434db32fb
--- /dev/null
+++ b/llama_stack/providers/remote/safety/nvidia/README.md
@@ -0,0 +1,77 @@
+# NVIDIA Safety Provider for LlamaStack
+
+This provider enables safety checks and guardrails for LLM interactions using NVIDIA's NeMo Guardrails service.
+
+## Features
+
+- Run safety checks on messages
+
+## Getting Started
+
+### Prerequisites
+
+- LlamaStack with the NVIDIA configuration
+- Access to the NVIDIA NeMo Guardrails service
+- A NIM deployed for the model used for safety checks
+
+### Setup
+
+Build the NVIDIA environment:
+
+```bash
+llama stack build --template nvidia --image-type conda
+```
+
+### Basic Usage with the LlamaStack Python Client
+
+#### Initialize the client
+
+```python
+import os
+
+os.environ["NVIDIA_API_KEY"] = "your-api-key"
+os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"
+
+from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+
+client = LlamaStackAsLibraryClient("nvidia")
+client.initialize()
+```
+
+#### Create a safety shield
+
+```python
+# Register a safety shield backed by your Guardrails configuration
+client.shields.register(
+    shield_id="your-shield-id",
+    provider_shield_id="safety-model-id",  # The model to use for safety checks
+)
+```
+
+#### Run safety checks
+
+```python
+from llama_stack.apis.inference import UserMessage
+
+# Messages to check
+messages = [UserMessage(content="Your message to check")]
+
+# Run safety check
+response = client.safety.run_shield(
+    shield_id="your-shield-id",
+    messages=messages,
+)
+
+# Check for violations
+if response.violation:
+    print(f"Safety violation detected: {response.violation.user_message}")
+    print(f"Violation level: {response.violation.violation_level}")
+    print(f"Metadata: {response.violation.metadata}")
+else:
+    print("No safety violations detected")
+```
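+
+#### Guarding an inference call
+
+A registered shield can also be used to screen a prompt before it is sent to a model. The snippet below is a minimal sketch combining the safety and inference APIs; it assumes the shield registered above, an inference model served by the NVIDIA inference provider, and placeholder IDs (`your-shield-id`, `meta-llama/Llama-3.1-8B-Instruct`) that you should replace with your own.
+
+```python
+from llama_stack.apis.inference import UserMessage
+
+prompt = "Write a short, friendly greeting."
+
+# Screen the prompt with the registered shield before running inference
+check = client.safety.run_shield(
+    shield_id="your-shield-id",
+    messages=[UserMessage(content=prompt)],
+)
+
+if check.violation:
+    # The Guardrails service flagged the prompt; do not call the model
+    print(f"Blocked by shield: {check.violation.user_message}")
+else:
+    # Prompt passed the safety check; forward it to the inference provider
+    response = client.inference.chat_completion(
+        model_id="meta-llama/Llama-3.1-8B-Instruct",
+        messages=[{"role": "user", "content": prompt}],
+        stream=False,
+    )
+    print(response.completion_message.content)
+```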