Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 13:00:39 +00:00)

commit 3b0a4d0f5e (parent 205cbcb46a)

    fixing errors

3 changed files with 29 additions and 13 deletions
@@ -22,7 +22,7 @@ class NVIDIASafetyConfig(BaseModel):
     Configuration for the NVIDIA Guardrail microservice endpoint.
 
     Attributes:
-        url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000
+        guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000
         api_key (str): The access key for the hosted NIM endpoints
 
     There are two ways to access NVIDIA NIMs -
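The rename above means callers must read config.guardrails_service_url rather than config.url. As a minimal sketch, the config class presumably ends up shaped like this (field names are taken from this diff and the run.yaml hunks further down; the default URL and the optionality of the two config fields are assumptions):

    # Sketch only -- not the verbatim file from this commit.
    from typing import Optional

    from pydantic import BaseModel


    class NVIDIASafetyConfig(BaseModel):
        # Base URL of the NVIDIA guardrails microservice.
        guardrails_service_url: str = "http://localhost:7331"
        # NeMoGuardrails (below) expects at least one of these two to be set.
        config_id: Optional[str] = None
        config_store_path: Optional[str] = None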
@@ -8,7 +8,7 @@ import json
 import logging
 from typing import Any, Dict, List
 
-from llama_stack.apis.inference import Message
+from llama_stack.apis.inference import Message, UserMessage
 from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
@@ -16,6 +16,8 @@ from llama_stack.apis.safety import (
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.distribution.library_client import convert_pydantic_to_json_value
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 import requests
 
@@ -23,10 +25,21 @@ from .config import NVIDIASafetyConfig
 
 logger = logging.getLogger(__name__)
 
+SHIELD_IDS_TO_MODEL_MAPPING = {
+    CoreModelId.llama3_8b_instruct.value: "meta/llama3-8b-instruct",
+    CoreModelId.llama3_70b_instruct.value: "meta/llama3-70b-instruct",
+    CoreModelId.llama3_1_8b_instruct.value: "meta/llama-3.1-8b-instruct",
+    CoreModelId.llama3_1_70b_instruct.value: "meta/llama-3.1-70b-instruct",
+    CoreModelId.llama3_1_405b_instruct.value: "meta/llama-3.1-405b-instruct",
+    CoreModelId.llama3_2_1b_instruct.value: "meta/llama-3.2-1b-instruct",
+    CoreModelId.llama3_2_3b_instruct.value: "meta/llama-3.2-3b-instruct",
+    CoreModelId.llama3_2_11b_vision_instruct.value: "meta/llama-3.2-11b-vision-instruct",
+    CoreModelId.llama3_2_90b_vision_instruct.value: "meta/llama-3.2-90b-vision-instruct"
+}
+
 class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
     def __init__(self, config: NVIDIASafetyConfig) -> None:
-        print(f"Initializing NVIDIASafetyAdapter({config.url})...")
+        print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
         self.config = config
         self.registered_shields = []
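Note on the new mapping: run_shield (next hunk) indexes SHIELD_IDS_TO_MODEL_MAPPING directly, so an unregistered shield ID surfaces as a bare KeyError. A guarded lookup would look like the sketch below (resolve_nim_model is a hypothetical helper, not part of this commit):

    def resolve_nim_model(shield_id: str) -> str:
        # Translate a llama-stack shield ID into the NIM model identifier.
        try:
            return SHIELD_IDS_TO_MODEL_MAPPING[shield_id]
        except KeyError:
            raise ValueError(f"No NIM model mapped for shield {shield_id!r}") from None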
@@ -46,10 +59,10 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
             raise ValueError(f"Shield {shield_id} not found")
-        self.shield = NeMoGuardrails(self.config, shield.provider_resource_id)
+        self.shield = NeMoGuardrails(self.config, SHIELD_IDS_TO_MODEL_MAPPING[shield.shield_id])
         return await self.shield.run(messages)
 
 
 class NeMoGuardrails:
     def __init__(
         self,
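For context, the call shape this hunk serves, as a usage sketch (safety_api, the shield ID, and the message content are illustrative; in practice the stack constructs the adapter and injects the shield store):

    response = await safety_api.run_shield(
        shield_id=CoreModelId.llama3_1_8b_instruct.value,
        messages=[UserMessage(content="Tell me how to pick a lock.")],
    )
    if response.violation is not None:
        # RunShieldResponse carries an optional violation with a user-facing message.
        print(response.violation.user_message)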
@@ -58,24 +71,25 @@ class NeMoGuardrails:
         threshold: float = 0.9,
         temperature: float = 1.0,
     ):
-        config_id = config["config_id"]
-        config_store_path = config["config_store_path"]
-        assert config_id is not None or config_store_path is not None, "Must provide one of config id or config store path"
+        self.config_id = config.config_id
+        self.config_store_path = config.config_store_path
+        self.model = model
+        assert self.config_id is not None or self.config_store_path is not None, "Must provide one of config id or config store path"
         if temperature <= 0:
             raise ValueError("Temperature must be greater than 0")
 
         self.config = config
         self.temperature = temperature
         self.threshold = threshold
-        self.guardrails_service_url = config["guardrails_service_url"]
+        self.guardrails_service_url = config.guardrails_service_url
 
     async def run(self, messages: List[Message]) -> RunShieldResponse:
         headers = {
             "Accept": "application/json",
         }
         request_data = {
-            "model": "meta/llama-3.1-8b-instruct",
-            "messages": messages,
+            "model": self.model,
+            "messages": convert_pydantic_to_json_value(messages),
             "temperature": self.temperature,
             "top_p": 1,
             "frequency_penalty": 0,
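The core fix in this hunk: NVIDIASafetyConfig is a Pydantic BaseModel, and BaseModel instances are not subscriptable, so the old config["config_id"] style raised TypeError at runtime; attribute access is the supported form. A minimal illustration (the Demo class is invented for the example):

    from pydantic import BaseModel

    class Demo(BaseModel):
        config_id: str = "self-check"

    demo = Demo()
    demo.config_id       # attribute access works: "self-check"
    # demo["config_id"]  # TypeError: 'Demo' object is not subscriptable

Similar reasoning motivates convert_pydantic_to_json_value(messages): the json= keyword of requests serializes with the standard json module, which cannot encode Pydantic message objects directly, so they must be converted to plain JSON values first.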
@@ -83,7 +97,7 @@ class NeMoGuardrails:
             "max_tokens": 160,
             "stream": False,
             "guardrails": {
-                "config_id": self.config["config_id"],
+                "config_id": self.config_id,
             }
         }
         response = requests.post(
@@ -91,6 +105,7 @@ class NeMoGuardrails:
             headers=headers,
             json=request_data
         )
+        print(response)
         response.raise_for_status()
         if 'Content-Type' in response.headers and response.headers['Content-Type'].startswith('application/json'):
             response_json = response.json()
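The Content-Type check above guards the response.json() call. For readers unfamiliar with the pattern, a self-contained sketch (the URL and payload are placeholders, and the else branch is an assumption since the hunk is truncated here):

    import requests

    guardrails_url = "http://localhost:7331"  # placeholder; the real endpoint path is outside this hunk
    response = requests.post(
        guardrails_url,
        headers={"Accept": "application/json"},
        json={"model": "meta/llama-3.1-8b-instruct"},
    )
    response.raise_for_status()  # raise on 4xx/5xx before touching the body
    if response.headers.get("Content-Type", "").startswith("application/json"):
        response_json = response.json()
    else:
        raise ValueError("Expected a JSON response from the guardrails service")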
@@ -30,7 +30,7 @@ providers:
     provider_type: remote::nvidia
     config:
       url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
-      config_id:
+      config_id: self-check
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
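${env.GUARDRAILS_SERVICE_URL:http://localhost:7331} uses the run-config substitution syntax: the environment variable's value if it is set, otherwise the default after the colon. A toy resolver showing the semantics (illustrative only, not the stack's actual implementation):

    import os
    import re

    def resolve_env(value: str) -> str:
        # Replace ${env.NAME:default} (default optional) with the env value or the default.
        pattern = r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}"
        return re.sub(pattern, lambda m: os.environ.get(m.group(1), m.group(2) or ""), value)

    print(resolve_env("${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}"))
    # -> http://localhost:7331 unless GUARDRAILS_SERVICE_URL is set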
@@ -139,6 +139,7 @@ models:
   model_type: llm
 shields:
 - shield_id: ${env.SAFETY_MODEL}
+  provider_id: nvidia
 vector_dbs: []
 datasets: []
 scoring_fns: []