From b90ff9ca1613ea6fede5beb41e2bc404eba7f851 Mon Sep 17 00:00:00 2001 From: Chantal D Gama Rose Date: Wed, 19 Feb 2025 14:50:48 -0800 Subject: [PATCH] add NVIDIA safety provider --- .../remote/safety/nvidia/__init__.py | 18 ++ .../providers/remote/safety/nvidia/config.py | 63 +++++++ .../providers/remote/safety/nvidia/nvidia.py | 110 +++++++++++++ llama_stack/templates/nvidia/build.yaml | 2 +- llama_stack/templates/nvidia/nvidia.py | 47 +++++- .../templates/nvidia/run-with-safety.yaml | 154 ++++++++++++++++++ 6 files changed, 389 insertions(+), 5 deletions(-) create mode 100644 llama_stack/providers/remote/safety/nvidia/__init__.py create mode 100644 llama_stack/providers/remote/safety/nvidia/config.py create mode 100644 llama_stack/providers/remote/safety/nvidia/nvidia.py create mode 100644 llama_stack/templates/nvidia/run-with-safety.yaml diff --git a/llama_stack/providers/remote/safety/nvidia/__init__.py b/llama_stack/providers/remote/safety/nvidia/__init__.py new file mode 100644 index 000000000..4677268c6 --- /dev/null +++ b/llama_stack/providers/remote/safety/nvidia/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from typing import Any + +from .config import NVIDIASafetyConfig + + +async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any: + from .nvidia import NVIDIASafetyAdapter + + impl = NVIDIASafetyAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/safety/nvidia/config.py b/llama_stack/providers/remote/safety/nvidia/config.py new file mode 100644 index 000000000..ef93a8d41 --- /dev/null +++ b/llama_stack/providers/remote/safety/nvidia/config.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from enum import Enum +import os +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field, SecretStr, field_validator + +from llama_models.schema_utils import json_schema_type + + +class ShieldType(Enum): + self_check = "self_check" + + +@json_schema_type +class NVIDIASafetyConfig(BaseModel): + """ + Configuration for the NVIDIA Guardrail microservice endpoint. + + Attributes: + url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000 + api_key (str): The access key for the hosted NIM endpoints + + There are two ways to access NVIDIA NIMs - + 0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com + 1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure + + By default the configuration is set to use the hosted APIs. This requires + an API key which can be obtained from https://ngc.nvidia.com/. + + By default the configuration will attempt to read the NVIDIA_API_KEY environment + variable to set the api_key. Please do not put your API key in code. + """ + guardrails_service_url: str = Field( + default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://0.0.0.0:7331"), + description="The url for accessing the guardrails service", + ) + config_id: Optional[str] = Field( + default="self-check", + description="Config ID to use from the config store" + ) + config_store_path: Optional[str] = Field( + default="/config-store", + description="Path to config store" + ) + + @classmethod + @field_validator("guard_type") + def validate_guard_type(cls, v): + if v not in [t.value for t in ShieldType]: + raise ValueError(f"Unknown shield type: {v}") + return v + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}", + "config_id": "self-check" + } diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py new file mode 100644 index 000000000..6da5bc54a --- /dev/null +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import logging +from typing import Any, Dict, List + +from llama_stack.apis.inference import Message +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield +from llama_stack.providers.datatypes import ShieldsProtocolPrivate +import requests + +from .config import NVIDIASafetyConfig + +logger = logging.getLogger(__name__) + + +class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): + def __init__(self, config: NVIDIASafetyConfig) -> None: + print(f"Initializing NVIDIASafetyAdapter({config.url})...") + self.config = config + self.registered_shields = [] + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_shield(self, shield: Shield) -> None: + if not shield.provider_resource_id: + raise ValueError(f"Shield model not provided. ") + + async def run_shield( + self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None + ) -> RunShieldResponse: + shield = await self.shield_store.get_shield(shield_id) + if not shield: + raise ValueError(f"Shield {shield_id} not found") + self.shield = NeMoGuardrails(self.config, shield.provider_resource_id) + return await self.shield.run(messages) + + + + +class NeMoGuardrails: + def __init__( + self, + config: NVIDIASafetyConfig, + model: str, + threshold: float = 0.9, + temperature: float = 1.0, + ): + config_id = config["config_id"] + config_store_path = config["config_store_path"] + assert config_id is not None or config_store_path is not None, "Must provide one of config id or config store path" + if temperature <= 0: + raise ValueError("Temperature must be greater than 0") + + self.config = config + self.temperature = temperature + self.threshold = threshold + self.guardrails_service_url = config["guardrails_service_url"] + + async def run(self, messages: List[Message]) -> RunShieldResponse: + headers = { + "Accept": "application/json", + } + request_data = { + "model": "meta/llama-3.1-8b-instruct", + "messages": messages, + "temperature": self.temperature, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 160, + "stream": False, + "guardrails": { + "config_id": self.config["config_id"], + } + } + response = requests.post( + url=f"{self.guardrails_service_url}/v1/guardrail/checks", + headers=headers, + json=request_data + ) + response.raise_for_status() + if 'Content-Type' in response.headers and response.headers['Content-Type'].startswith('application/json'): + response_json = response.json() + if response_json["status"] == "blocked": + user_message = "Sorry I cannot do this." + metadata = response_json["rails_status"] + + return RunShieldResponse( + violation=SafetyViolation( + user_message=user_message, + violation_level=ViolationLevel.ERROR, + metadata=metadata, + ) + ) + return RunShieldResponse(violation=None) \ No newline at end of file diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index e9748721a..63a227d4f 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -7,7 +7,7 @@ distribution_spec: vector_io: - inline::faiss safety: - - inline::llama-guard + - remote::nvidia agents: - inline::meta-reference telemetry: diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index ee22b5555..f7100ac72 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -6,8 +6,15 @@ from pathlib import Path +<<<<<<< Updated upstream from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput from llama_stack.models.llama.sku_list import all_registered_models +======= +from llama_models.sku_list import all_registered_models + +from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput +from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig +>>>>>>> Stashed changes from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -37,6 +44,19 @@ def get_distribution_template() -> DistributionTemplate: provider_type="remote::nvidia", config=NVIDIAConfig.sample_run_config(), ) + safety_provider = Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + config=NVIDIASafetyConfig.sample_run_config(), + ) + inference_model = ModelInput( + model_id="${env.INFERENCE_MODEL}", + provider_id="nvidia", + ) + safety_model = ModelInput( + model_id="${env.SAFETY_MODEL}", + provider_id="nvidia", + ) core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()} default_models = [ @@ -78,15 +98,34 @@ def get_distribution_template() -> DistributionTemplate: default_models=default_models, default_tool_groups=default_tool_groups, ), + "run-with-safety.yaml": RunConfigSettings( + provider_overrides={ + "inference": [ + inference_provider, + safety_provider, + ] + }, + default_models=[inference_model, safety_model], + default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")], + default_tool_groups=default_tool_groups, + ), }, run_config_env_vars={ - "LLAMASTACK_PORT": ( - "5001", - "Port for the Llama Stack distribution server", - ), "NVIDIA_API_KEY": ( "", "NVIDIA API Key", ), + "GUARDRAILS_SERVICE_URL": ( + "http://0.0.0.0:7331", + "URL for the NeMo Guardrails Service", + ), + "INFERENCE_MODEL": ( + "Llama3.1-8B-Instruct", + "Inference model", + ), + "SAFETY_MODEL": ( + "meta/llama-3.1-8b-instruct", + "Name of the model to use for safety", + ), }, ) diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml new file mode 100644 index 000000000..37f46acf2 --- /dev/null +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -0,0 +1,154 @@ +version: '2' +image_name: nvidia +apis: +- agents +- datasetio +- eval +- inference +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: nvidia + provider_type: remote::nvidia + config: + url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com} + api_key: ${env.NVIDIA_API_KEY:} + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db + safety: + - provider_id: nvidia + provider_type: remote::nvidia + config: + url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331} + config_id: + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: ${env.OTEL_SERVICE_NAME:llama-stack} + sinks: ${env.TELEMETRY_SINKS:console,sqlite} + sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db} + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: {} + - provider_id: localfs + provider_type: inline::localfs + config: {} + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db +models: +- metadata: {} + model_id: meta-llama/Llama-3-8B-Instruct + provider_id: nvidia + provider_model_id: meta/llama3-8b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3-70B-Instruct + provider_id: nvidia + provider_model_id: meta/llama3-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: nvidia + provider_model_id: meta/llama-3.1-8b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: nvidia + provider_model_id: meta/llama-3.1-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: nvidia + provider_model_id: meta/llama-3.1-405b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-1B-Instruct + provider_id: nvidia + provider_model_id: meta/llama-3.2-1b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-3B-Instruct + provider_id: nvidia + provider_model_id: meta/llama-3.2-3b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: nvidia + provider_model_id: meta/llama-3.2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: nvidia + provider_model_id: meta/llama-3.2-90b-vision-instruct + model_type: llm +shields: +- shield_id: ${env.SAFETY_MODEL} +vector_dbs: [] +datasets: [] +scoring_fns: [] +eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter +server: + port: 8321