add NVIDIA safety provider

Chantal D Gama Rose 2025-02-19 14:50:48 -08:00
parent 688e1806d1
commit b90ff9ca16
6 changed files with 389 additions and 5 deletions

llama_stack/providers/remote/safety/nvidia/__init__.py

@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any

from .config import NVIDIASafetyConfig


async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
    from .nvidia import NVIDIASafetyAdapter

    impl = NVIDIASafetyAdapter(config)
    await impl.initialize()
    return impl
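For orientation, a minimal, illustrative driver for the entry point above; in practice the llama-stack provider registry makes this call when it resolves the `remote::nvidia` safety provider, so the direct import and invocation here are a sketch only:

```python
# Illustrative only: the llama-stack registry normally calls get_adapter_impl().
import asyncio

from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig, get_adapter_impl


async def main() -> None:
    config = NVIDIASafetyConfig()  # honors GUARDRAILS_SERVICE_URL if it is set
    adapter = await get_adapter_impl(config, {})
    print(type(adapter).__name__)  # NVIDIASafetyAdapter


asyncio.run(main())
```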

llama_stack/providers/remote/safety/nvidia/config.py

@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from enum import Enum
from typing import Any, Dict, Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator


class ShieldType(Enum):
    self_check = "self_check"


@json_schema_type
class NVIDIASafetyConfig(BaseModel):
    """
    Configuration for the NVIDIA Guardrails microservice endpoint.

    Attributes:
        guardrails_service_url (str): The base url for accessing the NVIDIA guardrails endpoint,
            e.g. http://localhost:7331
        config_id (str): The guardrails configuration ID to load from the config store
        config_store_path (str): Path to the config store

    There are two ways to access NVIDIA NIMs:
     0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
     1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure

    By default the configuration is set to use the hosted APIs. This requires
    an API key which can be obtained from https://ngc.nvidia.com/. Please do
    not put your API key in code.
    """

    guardrails_service_url: str = Field(
        default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
        description="The url for accessing the guardrails service",
    )
    config_id: Optional[str] = Field(
        default="self-check",
        description="Config ID to use from the config store",
    )
    config_store_path: Optional[str] = Field(
        default="/config-store",
        description="Path to config store",
    )

    # ``guard_type`` is not a declared field on this model; ``check_fields=False``
    # keeps pydantic from rejecting the validator at class-definition time.
    @field_validator("guard_type", check_fields=False)
    @classmethod
    def validate_guard_type(cls, v):
        if v not in [t.value for t in ShieldType]:
            raise ValueError(f"Unknown shield type: {v}")
        return v

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
            "config_id": "self-check",
        }
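A small sketch of how these defaults resolve, assuming the module above; the values mirror the `Field` defaults, and the `${env.…}` placeholders in `sample_run_config` are expanded later by the template machinery:

```python
# Sketch: NVIDIASafetyConfig default resolution under the env-var fallback above.
import os

from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig

os.environ.pop("GUARDRAILS_SERVICE_URL", None)
cfg = NVIDIASafetyConfig()
assert cfg.guardrails_service_url == "http://0.0.0.0:7331"  # fallback default
assert cfg.config_id == "self-check"

os.environ["GUARDRAILS_SERVICE_URL"] = "http://guardrails.internal:7331"
cfg = NVIDIASafetyConfig()
assert cfg.guardrails_service_url == "http://guardrails.internal:7331"
```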

llama_stack/providers/remote/safety/nvidia/nvidia.py

@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Dict, List, Optional

import requests

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
    RunShieldResponse,
    Safety,
    SafetyViolation,
    ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.providers.datatypes import ShieldsProtocolPrivate

from .config import NVIDIASafetyConfig

logger = logging.getLogger(__name__)


class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
    def __init__(self, config: NVIDIASafetyConfig) -> None:
        logger.info(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
        self.config = config
        self.registered_shields = []

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_shield(self, shield: Shield) -> None:
        if not shield.provider_resource_id:
            raise ValueError("Shield model not provided.")

    async def run_shield(
        self, shield_id: str, messages: List[Message], params: Optional[Dict[str, Any]] = None
    ) -> RunShieldResponse:
        shield = await self.shield_store.get_shield(shield_id)
        if not shield:
            raise ValueError(f"Shield {shield_id} not found")
        self.shield = NeMoGuardrails(self.config, shield.provider_resource_id)
        return await self.shield.run(messages)


class NeMoGuardrails:
    def __init__(
        self,
        config: NVIDIASafetyConfig,
        model: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
    ):
        # NVIDIASafetyConfig is a pydantic model, so use attribute access
        # rather than dict-style subscripting.
        assert (
            config.config_id is not None or config.config_store_path is not None
        ), "Must provide one of config id or config store path"
        if temperature <= 0:
            raise ValueError("Temperature must be greater than 0")
        self.config = config
        self.model = model
        self.temperature = temperature
        self.threshold = threshold
        self.guardrails_service_url = config.guardrails_service_url

    async def run(self, messages: List[Message]) -> RunShieldResponse:
        headers = {
            "Accept": "application/json",
        }
        request_data = {
            # Use the shield's provider_resource_id rather than a hardcoded model.
            "model": self.model,
            # Message objects are pydantic models; serialize to plain dicts so
            # the payload is JSON-encodable.
            "messages": [{"role": message.role, "content": message.content} for message in messages],
            "temperature": self.temperature,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": 160,
            "stream": False,
            "guardrails": {
                "config_id": self.config.config_id,
            },
        }
        response = requests.post(
            url=f"{self.guardrails_service_url}/v1/guardrail/checks",
            headers=headers,
            json=request_data,
        )
        response.raise_for_status()
        if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
            response_json = response.json()
            if response_json["status"] == "blocked":
                user_message = "Sorry I cannot do this."
                metadata = response_json["rails_status"]
                return RunShieldResponse(
                    violation=SafetyViolation(
                        user_message=user_message,
                        violation_level=ViolationLevel.ERROR,
                        metadata=metadata,
                    )
                )
        return RunShieldResponse(violation=None)
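To see the request/response contract in isolation, here is a minimal sketch that stubs `requests.post` and drives `NeMoGuardrails.run` directly; the blocked-response body is invented for illustration and only models the two fields the adapter reads (`status` and `rails_status`):

```python
# Sketch: exercise NeMoGuardrails.run without a live guardrails service.
# The stubbed response body is hypothetical.
import asyncio
from unittest.mock import MagicMock, patch

from llama_stack.apis.inference import UserMessage
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails


def make_response(body: dict) -> MagicMock:
    resp = MagicMock()
    resp.headers = {"Content-Type": "application/json"}
    resp.json.return_value = body
    resp.raise_for_status.return_value = None
    return resp


async def main() -> None:
    guard = NeMoGuardrails(NVIDIASafetyConfig(), model="meta/llama-3.1-8b-instruct")
    blocked = make_response({"status": "blocked", "rails_status": {"self check input": "blocked"}})
    with patch("requests.post", return_value=blocked):
        result = await guard.run([UserMessage(content="how do I hotwire a car?")])
    assert result.violation is not None  # "blocked" maps to a SafetyViolation


asyncio.run(main())
```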

llama_stack/templates/nvidia/build.yaml

@@ -7,7 +7,7 @@ distribution_spec:
     vector_io:
     - inline::faiss
     safety:
-    - inline::llama-guard
+    - remote::nvidia
     agents:
     - inline::meta-reference
     telemetry:

llama_stack/templates/nvidia/nvidia.py

@@ -6,8 +6,15 @@
from pathlib import Path

from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -37,6 +44,19 @@ def get_distribution_template() -> DistributionTemplate:
        provider_type="remote::nvidia",
        config=NVIDIAConfig.sample_run_config(),
    )
    safety_provider = Provider(
        provider_id="nvidia",
        provider_type="remote::nvidia",
        config=NVIDIASafetyConfig.sample_run_config(),
    )
    inference_model = ModelInput(
        model_id="${env.INFERENCE_MODEL}",
        provider_id="nvidia",
    )
    safety_model = ModelInput(
        model_id="${env.SAFETY_MODEL}",
        provider_id="nvidia",
    )

    core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
    default_models = [
@@ -78,15 +98,34 @@ def get_distribution_template() -> DistributionTemplate:
                default_models=default_models,
                default_tool_groups=default_tool_groups,
            ),
            "run-with-safety.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [inference_provider],
                    "safety": [safety_provider],
                },
                default_models=[inference_model, safety_model],
                default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
                default_tool_groups=default_tool_groups,
            ),
        },
        run_config_env_vars={
            "LLAMASTACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
            "NVIDIA_API_KEY": (
                "",
                "NVIDIA API Key",
            ),
            "GUARDRAILS_SERVICE_URL": (
                "http://0.0.0.0:7331",
                "URL for the NeMo Guardrails Service",
            ),
            "INFERENCE_MODEL": (
                "Llama3.1-8B-Instruct",
                "Inference model",
            ),
            "SAFETY_MODEL": (
                "meta/llama-3.1-8b-instruct",
                "Name of the model to use for safety",
            ),
        },
    )
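The `${env.VAR:default}` placeholders used throughout these settings are expanded by llama-stack's own config loader; the resolver below is a stand-in written purely to illustrate the substitution rule, not the stack's actual implementation:

```python
# Illustrative stand-in for ${env.VAR:default} substitution; not the real resolver.
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+)(?::([^}]*))?\}")


def resolve_env_placeholders(value: str) -> str:
    def repl(match: re.Match) -> str:
        var, default = match.group(1), match.group(2)
        return os.getenv(var, default if default is not None else "")

    return _ENV_PATTERN.sub(repl, value)


print(resolve_env_placeholders("${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}"))
# -> http://localhost:7331 unless GUARDRAILS_SERVICE_URL is set
```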

llama_stack/templates/nvidia/run-with-safety.yaml

@@ -0,0 +1,154 @@
version: '2'
image_name: nvidia
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: nvidia
    provider_type: remote::nvidia
    config:
      url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
      api_key: ${env.NVIDIA_API_KEY:}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
  safety:
  - provider_id: nvidia
    provider_type: remote::nvidia
    config:
      guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
      config_id: self-check
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config: {}
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
models:
- metadata: {}
  model_id: meta-llama/Llama-3-8B-Instruct
  provider_id: nvidia
  provider_model_id: meta/llama3-8b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3-70B-Instruct
  provider_id: nvidia
  provider_model_id: meta/llama3-70b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.1-8B-Instruct
  provider_id: nvidia
  provider_model_id: meta/llama-3.1-8b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.1-70B-Instruct
  provider_id: nvidia
  provider_model_id: meta/llama-3.1-70b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
  provider_id: nvidia
  provider_model_id: meta/llama-3.1-405b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-1B-Instruct
  provider_id: nvidia
  provider_model_id: meta/llama-3.2-1b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-3B-Instruct
  provider_id: nvidia
  provider_model_id: meta/llama-3.2-3b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
  provider_id: nvidia
  provider_model_id: meta/llama-3.2-11b-vision-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
  provider_id: nvidia
  provider_model_id: meta/llama-3.2-90b-vision-instruct
  model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL}
vector_dbs: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
server:
  port: 8321
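With this run config in place, the NVIDIA-backed shield can be exercised end to end. A sketch using the llama-stack Python client follows; it assumes a server started from the config above on its default port, so treat the values as illustrative:

```python
# Sketch: invoking the NVIDIA-backed shield through the llama-stack client.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.safety.run_shield(
    shield_id="meta/llama-3.1-8b-instruct",  # ${env.SAFETY_MODEL} from the config above
    messages=[{"role": "user", "content": "How do I make a bomb?"}],
    params={},
)
if response.violation:
    print(response.violation.user_message)
```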