added nvidia as safety provider

This commit is contained in:
Chantal D Gama Rose 2025-02-25 08:16:49 +00:00
parent 07a992ef90
commit 0593408c19
14 changed files with 354 additions and 78 deletions

View file

@ -390,16 +390,13 @@
], ],
"nvidia": [ "nvidia": [
"aiosqlite", "aiosqlite",
"autoevals",
"blobfile", "blobfile",
"chardet", "chardet",
"datasets",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fire",
"httpx", "httpx",
"matplotlib", "matplotlib",
"mcp",
"nltk", "nltk",
"numpy", "numpy",
"openai", "openai",

View file

@ -6,13 +6,13 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
| API | Provider(s) | | API | Provider(s) |
|-----|-------------| |-----|-------------|
| agents | `inline::meta-reference` | | agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` | | datasetio | `inline::localfs` |
| eval | `inline::meta-reference` | | eval | `inline::meta-reference` |
| inference | `remote::nvidia` | | inference | `remote::nvidia` |
| safety | `inline::llama-guard` | | safety | `remote::nvidia` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | scoring | `inline::basic` |
| telemetry | `inline::meta-reference` | | telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` | | tool_runtime | `inline::rag-runtime` |
| vector_io | `inline::faiss` | | vector_io | `inline::faiss` |
@ -20,8 +20,10 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
The following environment variables can be configured: The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``) - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
### Models ### Models

View file

@ -85,4 +85,13 @@ Provider `inline::meta-reference` for API `safety` does not work with the latest
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
), ),
), ),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
),
),
] ]

View file

@ -7,7 +7,7 @@
import os import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -39,7 +39,7 @@ class NVIDIAConfig(BaseModel):
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"), default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
description="A base url for accessing the NVIDIA NIM", description="A base url for accessing the NVIDIA NIM",
) )
api_key: Optional[SecretStr] = Field( api_key: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"), default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key, only needed of using the hosted service", description="The NVIDIA API key, only needed of using the hosted service",
) )

View file

@ -85,7 +85,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# make sure the client lives longer than any async calls # make sure the client lives longer than any async calls
self._client = AsyncOpenAI( self._client = AsyncOpenAI(
base_url=f"{self._config.url}/v1", base_url=f"{self._config.url}/v1",
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), api_key=(self._config.api_key if self._config.api_key else "NO KEY"),
timeout=self._config.timeout, timeout=self._config.timeout,
) )

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import NVIDIASafetyConfig
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
from .nvidia import NVIDIASafetyAdapter
impl = NVIDIASafetyAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class NVIDIASafetyConfig(BaseModel):
"""
Configuration for the NVIDIA Guardrail microservice endpoint.
Attributes:
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000
api_key (str): The access key for the hosted NIM endpoints
There are two ways to access NVIDIA NIMs -
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure
By default the configuration is set to use the hosted APIs. This requires
an API key which can be obtained from https://ngc.nvidia.com/.
By default the configuration will attempt to read the NVIDIA_API_KEY environment
variable to set the api_key. Please do not put your API key in code.
"""
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the guardrails service",
)
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
"config_id": "self-check",
}

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Dict, List
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
async def run_shield(
self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
class NeMoGuardrails:
def __init__(
self,
config: NVIDIASafetyConfig,
model: str,
threshold: float = 0.9,
temperature: float = 1.0,
):
self.config_id = config.config_id
self.model = model
assert self.config_id is not None or self.config_store_path is not None, (
"Must provide one of config id or config store path"
)
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.temperature = temperature
self.threshold = threshold
self.guardrails_service_url = config.guardrails_service_url
async def run(self, messages: List[Message]) -> RunShieldResponse:
headers = {
"Accept": "application/json",
}
request_data = {
"model": self.model,
"messages": convert_pydantic_to_json_value(messages),
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": self.config_id,
},
}
response = requests.post(
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
)
response.raise_for_status()
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
response_json = response.json()
if response_json["status"] == "blocked":
user_message = "Sorry I cannot do this."
metadata = response_json["rails_status"]
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse(violation=None)

View file

@ -51,11 +51,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="remote", id="remote",
marks=pytest.mark.remote, marks=pytest.mark.remote,
), ),
pytest.param(
{
"inference": "nvidia",
"safety": "nvidia",
},
id="nvidia",
marks=pytest.mark.nvidia,
),
] ]
def pytest_configure(config): def pytest_configure(config):
for mark in ["meta_reference", "ollama", "together", "remote", "bedrock"]: for mark in ["meta_reference", "ollama", "together", "remote", "bedrock", "nvidia"]:
config.addinivalue_line( config.addinivalue_line(
"markers", "markers",
f"{mark}: marks tests as {mark} specific", f"{mark}: marks tests as {mark} specific",

View file

@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture from ..conftest import ProviderFixture, remote_stack_fixture
@ -95,7 +96,20 @@ def safety_bedrock() -> ProviderFixture:
) )
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"] @pytest.fixture(scope="session")
def safety_nvidia() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIASafetyConfig().model_dump(),
)
],
)
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "nvidia"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")

View file

@ -1,13 +1,13 @@
version: '2' version: '2'
distribution_spec: distribution_spec:
description: Use NVIDIA NIM for running LLM inference description: Use NVIDIA NIM for running LLM inference and safety
providers: providers:
inference: inference:
- remote::nvidia - remote::nvidia
vector_io: vector_io:
- inline::faiss - inline::faiss
safety: safety:
- inline::llama-guard - remote::nvidia
agents: agents:
- inline::meta-reference - inline::meta-reference
telemetry: telemetry:
@ -15,16 +15,9 @@ distribution_spec:
eval: eval:
- inline::meta-reference - inline::meta-reference
datasetio: datasetio:
- remote::huggingface
- inline::localfs - inline::localfs
scoring: scoring:
- inline::basic - inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime: tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime - inline::rag-runtime
- remote::model-context-protocol
image_type: conda image_type: conda

View file

@ -10,25 +10,23 @@ from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
providers = { providers = {
"inference": ["remote::nvidia"], "inference": ["remote::nvidia"],
"vector_io": ["inline::faiss"], "vector_io": ["inline::faiss"],
"safety": ["inline::llama-guard"], "safety": ["remote::nvidia"],
"agents": ["inline::meta-reference"], "agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"], "eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"], "datasetio": ["inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], "scoring": ["inline::basic"],
"tool_runtime": [ "tool_runtime": ["inline::rag-runtime"],
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
} }
inference_provider = Provider( inference_provider = Provider(
@ -36,30 +34,35 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::nvidia", provider_type="remote::nvidia",
config=NVIDIAConfig.sample_run_config(), config=NVIDIAConfig.sample_run_config(),
) )
safety_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIASafetyConfig.sample_run_config(),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="nvidia",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="nvidia",
)
available_models = { available_models = {
"nvidia": MODEL_ENTRIES, "nvidia": MODEL_ENTRIES,
} }
default_tool_groups = [ default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput( ToolGroupInput(
toolgroup_id="builtin::rag", toolgroup_id="builtin::rag",
provider_id="rag-runtime", provider_id="rag-runtime",
), ),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
] ]
default_models = get_model_registry(available_models) default_models = get_model_registry(available_models)
return DistributionTemplate( return DistributionTemplate(
name="nvidia", name="nvidia",
distro_type="remote_hosted", distro_type="remote_hosted",
description="Use NVIDIA NIM for running LLM inference", description="Use NVIDIA NIM for running LLM inference and safety",
container_image=None, container_image=None,
template_path=Path(__file__).parent / "doc_template.md", template_path=Path(__file__).parent / "doc_template.md",
providers=providers, providers=providers,
@ -72,15 +75,34 @@ def get_distribution_template() -> DistributionTemplate:
default_models=default_models, default_models=default_models,
default_tool_groups=default_tool_groups, default_tool_groups=default_tool_groups,
), ),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
safety_provider,
]
},
default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
default_tool_groups=default_tool_groups,
),
}, },
run_config_env_vars={ run_config_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"NVIDIA_API_KEY": ( "NVIDIA_API_KEY": (
"", "",
"NVIDIA API Key", "NVIDIA API Key",
), ),
"GUARDRAILS_SERVICE_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Guardrails Service",
),
"INFERENCE_MODEL": (
"Llama3.1-8B-Instruct",
"Inference model",
),
"SAFETY_MODEL": (
"meta/llama-3.1-8b-instruct",
"Name of the model to use for safety",
),
}, },
) )

View file

@ -0,0 +1,93 @@
version: '2'
image_name: nvidia
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:}
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
safety:
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
datasetio:
- provider_id: localfs
provider_type: inline::localfs
config: {}
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: nvidia
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: nvidia
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -26,9 +26,11 @@ providers:
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
safety: safety:
- provider_id: llama-guard - provider_id: nvidia
provider_type: inline::llama-guard provider_type: remote::nvidia
config: {} config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -49,9 +51,6 @@ providers:
provider_type: inline::meta-reference provider_type: inline::meta-reference
config: {} config: {}
datasetio: datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config: {}
- provider_id: localfs - provider_id: localfs
provider_type: inline::localfs provider_type: inline::localfs
config: {} config: {}
@ -59,33 +58,10 @@ providers:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {} config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime: tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {} config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
@ -214,11 +190,7 @@ datasets: []
scoring_fns: [] scoring_fns: []
benchmarks: [] benchmarks: []
tool_groups: tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag - toolgroup_id: builtin::rag
provider_id: rag-runtime provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server: server:
port: 8321 port: 8321