added nvidia as safety provider

This commit is contained in:
Chantal D Gama Rose 2025-02-25 08:16:49 +00:00
parent 07a992ef90
commit 0593408c19
14 changed files with 354 additions and 78 deletions

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import NVIDIASafetyConfig
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
from .nvidia import NVIDIASafetyAdapter
impl = NVIDIASafetyAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class NVIDIASafetyConfig(BaseModel):
"""
Configuration for the NVIDIA Guardrail microservice endpoint.
Attributes:
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000
api_key (str): The access key for the hosted NIM endpoints
There are two ways to access NVIDIA NIMs -
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure
By default the configuration is set to use the hosted APIs. This requires
an API key which can be obtained from https://ngc.nvidia.com/.
By default the configuration will attempt to read the NVIDIA_API_KEY environment
variable to set the api_key. Please do not put your API key in code.
"""
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the guardrails service",
)
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
"config_id": "self-check",
}

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Dict, List
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
async def run_shield(
self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
class NeMoGuardrails:
def __init__(
self,
config: NVIDIASafetyConfig,
model: str,
threshold: float = 0.9,
temperature: float = 1.0,
):
self.config_id = config.config_id
self.model = model
assert self.config_id is not None or self.config_store_path is not None, (
"Must provide one of config id or config store path"
)
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.temperature = temperature
self.threshold = threshold
self.guardrails_service_url = config.guardrails_service_url
async def run(self, messages: List[Message]) -> RunShieldResponse:
headers = {
"Accept": "application/json",
}
request_data = {
"model": self.model,
"messages": convert_pydantic_to_json_value(messages),
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": self.config_id,
},
}
response = requests.post(
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
)
response.raise_for_status()
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
response_json = response.json()
if response_json["status"] == "blocked":
user_message = "Sorry I cannot do this."
metadata = response_json["rails_status"]
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse(violation=None)