From 3b0a4d0f5ed653a7ebfc4375a302463e74b41469 Mon Sep 17 00:00:00 2001 From: Chantal D Gama Rose Date: Wed, 19 Feb 2025 23:42:16 -0800 Subject: [PATCH] fixing errors --- .../providers/remote/safety/nvidia/config.py | 2 +- .../providers/remote/safety/nvidia/nvidia.py | 37 +++++++++++++------ .../templates/nvidia/run-with-safety.yaml | 3 +- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/llama_stack/providers/remote/safety/nvidia/config.py b/llama_stack/providers/remote/safety/nvidia/config.py index 44e1a986f..d98278c94 100644 --- a/llama_stack/providers/remote/safety/nvidia/config.py +++ b/llama_stack/providers/remote/safety/nvidia/config.py @@ -22,7 +22,7 @@ class NVIDIASafetyConfig(BaseModel): Configuration for the NVIDIA Guardrail microservice endpoint. Attributes: - url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000 + guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000 api_key (str): The access key for the hosted NIM endpoints There are two ways to access NVIDIA NIMs - diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index c1918f0bc..9b5d051dd 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -8,7 +8,7 @@ import json import logging from typing import Any, Dict, List -from llama_stack.apis.inference import Message +from llama_stack.apis.inference import Message, UserMessage from llama_stack.apis.safety import ( RunShieldResponse, Safety, @@ -16,6 +16,8 @@ from llama_stack.apis.safety import ( ViolationLevel, ) from llama_stack.apis.shields import Shield +from llama_stack.distribution.library_client import convert_pydantic_to_json_value +from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.providers.datatypes import ShieldsProtocolPrivate import requests @@ -23,10 +25,21 @@ from .config import NVIDIASafetyConfig logger = logging.getLogger(__name__) +SHIELD_IDS_TO_MODEL_MAPPING = { + CoreModelId.llama3_8b_instruct.value: "meta/llama3-8b-instruct", + CoreModelId.llama3_70b_instruct.value: "meta/llama3-70b-instruct", + CoreModelId.llama3_1_8b_instruct.value: "meta/llama-3.1-8b-instruct", + CoreModelId.llama3_1_70b_instruct.value: "meta/llama-3.1-70b-instruct", + CoreModelId.llama3_1_405b_instruct.value: "meta/llama-3.1-405b-instruct", + CoreModelId.llama3_2_1b_instruct.value: "meta/llama-3.2-1b-instruct", + CoreModelId.llama3_2_3b_instruct.value: "meta/llama-3.2-3b-instruct", + CoreModelId.llama3_2_11b_vision_instruct.value: "meta/llama-3.2-11b-vision-instruct", + CoreModelId.llama3_2_90b_vision_instruct.value: "meta/llama-3.2-90b-vision-instruct" +} class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: NVIDIASafetyConfig) -> None: - print(f"Initializing NVIDIASafetyAdapter({config.url})...") + print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...") self.config = config self.registered_shields = [] @@ -46,10 +59,10 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): shield = await self.shield_store.get_shield(shield_id) if not shield: raise ValueError(f"Shield {shield_id} not found") - self.shield = NeMoGuardrails(self.config, shield.provider_resource_id) + self.shield = NeMoGuardrails(self.config, SHIELD_IDS_TO_MODEL_MAPPING[shield.shield_id]) return await self.shield.run(messages) - + class NeMoGuardrails: def __init__( self, @@ -58,24 +71,25 @@ class NeMoGuardrails: threshold: float = 0.9, temperature: float = 1.0, ): - config_id = config["config_id"] - config_store_path = config["config_store_path"] - assert config_id is not None or config_store_path is not None, "Must provide one of config id or config store path" + self.config_id = config.config_id + self.config_store_path = config.config_store_path + self.model = model + assert self.config_id is not None or self.config_store_path is not None, "Must provide one of config id or config store path" if temperature <= 0: raise ValueError("Temperature must be greater than 0") self.config = config self.temperature = temperature self.threshold = threshold - self.guardrails_service_url = config["guardrails_service_url"] + self.guardrails_service_url = config.guardrails_service_url async def run(self, messages: List[Message]) -> RunShieldResponse: headers = { "Accept": "application/json", } request_data = { - "model": "meta/llama-3.1-8b-instruct", - "messages": messages, + "model": self.model, + "messages": convert_pydantic_to_json_value(messages), "temperature": self.temperature, "top_p": 1, "frequency_penalty": 0, @@ -83,7 +97,7 @@ class NeMoGuardrails: "max_tokens": 160, "stream": False, "guardrails": { - "config_id": self.config["config_id"], + "config_id": self.config_id, } } response = requests.post( @@ -91,6 +105,7 @@ class NeMoGuardrails: headers=headers, json=request_data ) + print(response) response.raise_for_status() if 'Content-Type' in response.headers and response.headers['Content-Type'].startswith('application/json'): response_json = response.json() diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 37f46acf2..bfb346c7d 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -30,7 +30,7 @@ providers: provider_type: remote::nvidia config: url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331} - config_id: + config_id: self-check agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -139,6 +139,7 @@ models: model_type: llm shields: - shield_id: ${env.SAFETY_MODEL} + provider_id: nvidia vector_dbs: [] datasets: [] scoring_fns: []