fixing errors

This commit is contained in:
Chantal D Gama Rose 2025-02-19 23:42:16 -08:00
parent 205cbcb46a
commit 3b0a4d0f5e
3 changed files with 29 additions and 13 deletions

View file

@ -22,7 +22,7 @@ class NVIDIASafetyConfig(BaseModel):
Configuration for the NVIDIA Guardrail microservice endpoint.
Attributes:
url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000
api_key (str): The access key for the hosted NIM endpoints
There are two ways to access NVIDIA NIMs -

View file

@ -8,7 +8,7 @@ import json
import logging
from typing import Any, Dict, List
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import Message, UserMessage
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -16,6 +16,8 @@ from llama_stack.apis.safety import (
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
import requests
@ -23,10 +25,21 @@ from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
SHIELD_IDS_TO_MODEL_MAPPING = {
CoreModelId.llama3_8b_instruct.value: "meta/llama3-8b-instruct",
CoreModelId.llama3_70b_instruct.value: "meta/llama3-70b-instruct",
CoreModelId.llama3_1_8b_instruct.value: "meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_70b_instruct.value: "meta/llama-3.1-70b-instruct",
CoreModelId.llama3_1_405b_instruct.value: "meta/llama-3.1-405b-instruct",
CoreModelId.llama3_2_1b_instruct.value: "meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_3b_instruct.value: "meta/llama-3.2-3b-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value: "meta/llama-3.2-11b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value: "meta/llama-3.2-90b-vision-instruct"
}
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:
print(f"Initializing NVIDIASafetyAdapter({config.url})...")
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
self.config = config
self.registered_shields = []
@ -46,10 +59,10 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
self.shield = NeMoGuardrails(self.config, shield.provider_resource_id)
self.shield = NeMoGuardrails(self.config, SHIELD_IDS_TO_MODEL_MAPPING[shield.shield_id])
return await self.shield.run(messages)
class NeMoGuardrails:
def __init__(
self,
@ -58,24 +71,25 @@ class NeMoGuardrails:
threshold: float = 0.9,
temperature: float = 1.0,
):
config_id = config["config_id"]
config_store_path = config["config_store_path"]
assert config_id is not None or config_store_path is not None, "Must provide one of config id or config store path"
self.config_id = config.config_id
self.config_store_path = config.config_store_path
self.model = model
assert self.config_id is not None or self.config_store_path is not None, "Must provide one of config id or config store path"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.config = config
self.temperature = temperature
self.threshold = threshold
self.guardrails_service_url = config["guardrails_service_url"]
self.guardrails_service_url = config.guardrails_service_url
async def run(self, messages: List[Message]) -> RunShieldResponse:
headers = {
"Accept": "application/json",
}
request_data = {
"model": "meta/llama-3.1-8b-instruct",
"messages": messages,
"model": self.model,
"messages": convert_pydantic_to_json_value(messages),
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,
@ -83,7 +97,7 @@ class NeMoGuardrails:
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": self.config["config_id"],
"config_id": self.config_id,
}
}
response = requests.post(
@ -91,6 +105,7 @@ class NeMoGuardrails:
headers=headers,
json=request_data
)
print(response)
response.raise_for_status()
if 'Content-Type' in response.headers and response.headers['Content-Type'].startswith('application/json'):
response_json = response.json()

View file

@ -30,7 +30,7 @@ providers:
provider_type: remote::nvidia
config:
url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id:
config_id: self-check
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
@ -139,6 +139,7 @@ models:
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia
vector_dbs: []
datasets: []
scoring_fns: []