mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-13 21:29:57 +00:00
fixing errors
This commit is contained in:
parent
205cbcb46a
commit
3b0a4d0f5e
3 changed files with 29 additions and 13 deletions
|
@ -22,7 +22,7 @@ class NVIDIASafetyConfig(BaseModel):
|
||||||
Configuration for the NVIDIA Guardrail microservice endpoint.
|
Configuration for the NVIDIA Guardrail microservice endpoint.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000
|
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000
|
||||||
api_key (str): The access key for the hosted NIM endpoints
|
api_key (str): The access key for the hosted NIM endpoints
|
||||||
|
|
||||||
There are two ways to access NVIDIA NIMs -
|
There are two ways to access NVIDIA NIMs -
|
||||||
|
|
|
@ -8,7 +8,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import Message, UserMessage
|
||||||
from llama_stack.apis.safety import (
|
from llama_stack.apis.safety import (
|
||||||
RunShieldResponse,
|
RunShieldResponse,
|
||||||
Safety,
|
Safety,
|
||||||
|
@ -16,6 +16,8 @@ from llama_stack.apis.safety import (
|
||||||
ViolationLevel,
|
ViolationLevel,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
|
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||||
|
from llama_stack.models.llama.datatypes import CoreModelId
|
||||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
@ -23,10 +25,21 @@ from .config import NVIDIASafetyConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SHIELD_IDS_TO_MODEL_MAPPING = {
|
||||||
|
CoreModelId.llama3_8b_instruct.value: "meta/llama3-8b-instruct",
|
||||||
|
CoreModelId.llama3_70b_instruct.value: "meta/llama3-70b-instruct",
|
||||||
|
CoreModelId.llama3_1_8b_instruct.value: "meta/llama-3.1-8b-instruct",
|
||||||
|
CoreModelId.llama3_1_70b_instruct.value: "meta/llama-3.1-70b-instruct",
|
||||||
|
CoreModelId.llama3_1_405b_instruct.value: "meta/llama-3.1-405b-instruct",
|
||||||
|
CoreModelId.llama3_2_1b_instruct.value: "meta/llama-3.2-1b-instruct",
|
||||||
|
CoreModelId.llama3_2_3b_instruct.value: "meta/llama-3.2-3b-instruct",
|
||||||
|
CoreModelId.llama3_2_11b_vision_instruct.value: "meta/llama-3.2-11b-vision-instruct",
|
||||||
|
CoreModelId.llama3_2_90b_vision_instruct.value: "meta/llama-3.2-90b-vision-instruct"
|
||||||
|
}
|
||||||
|
|
||||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
||||||
print(f"Initializing NVIDIASafetyAdapter({config.url})...")
|
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
|
||||||
self.config = config
|
self.config = config
|
||||||
self.registered_shields = []
|
self.registered_shields = []
|
||||||
|
|
||||||
|
@ -46,7 +59,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
shield = await self.shield_store.get_shield(shield_id)
|
shield = await self.shield_store.get_shield(shield_id)
|
||||||
if not shield:
|
if not shield:
|
||||||
raise ValueError(f"Shield {shield_id} not found")
|
raise ValueError(f"Shield {shield_id} not found")
|
||||||
self.shield = NeMoGuardrails(self.config, shield.provider_resource_id)
|
self.shield = NeMoGuardrails(self.config, SHIELD_IDS_TO_MODEL_MAPPING[shield.shield_id])
|
||||||
return await self.shield.run(messages)
|
return await self.shield.run(messages)
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,24 +71,25 @@ class NeMoGuardrails:
|
||||||
threshold: float = 0.9,
|
threshold: float = 0.9,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
):
|
):
|
||||||
config_id = config["config_id"]
|
self.config_id = config.config_id
|
||||||
config_store_path = config["config_store_path"]
|
self.config_store_path = config.config_store_path
|
||||||
assert config_id is not None or config_store_path is not None, "Must provide one of config id or config store path"
|
self.model = model
|
||||||
|
assert self.config_id is not None or self.config_store_path is not None, "Must provide one of config id or config store path"
|
||||||
if temperature <= 0:
|
if temperature <= 0:
|
||||||
raise ValueError("Temperature must be greater than 0")
|
raise ValueError("Temperature must be greater than 0")
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
self.guardrails_service_url = config["guardrails_service_url"]
|
self.guardrails_service_url = config.guardrails_service_url
|
||||||
|
|
||||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||||
headers = {
|
headers = {
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
}
|
}
|
||||||
request_data = {
|
request_data = {
|
||||||
"model": "meta/llama-3.1-8b-instruct",
|
"model": self.model,
|
||||||
"messages": messages,
|
"messages": convert_pydantic_to_json_value(messages),
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_p": 1,
|
"top_p": 1,
|
||||||
"frequency_penalty": 0,
|
"frequency_penalty": 0,
|
||||||
|
@ -83,7 +97,7 @@ class NeMoGuardrails:
|
||||||
"max_tokens": 160,
|
"max_tokens": 160,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"guardrails": {
|
"guardrails": {
|
||||||
"config_id": self.config["config_id"],
|
"config_id": self.config_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
|
@ -91,6 +105,7 @@ class NeMoGuardrails:
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=request_data
|
json=request_data
|
||||||
)
|
)
|
||||||
|
print(response)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
if 'Content-Type' in response.headers and response.headers['Content-Type'].startswith('application/json'):
|
if 'Content-Type' in response.headers and response.headers['Content-Type'].startswith('application/json'):
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|
|
@ -30,7 +30,7 @@ providers:
|
||||||
provider_type: remote::nvidia
|
provider_type: remote::nvidia
|
||||||
config:
|
config:
|
||||||
url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
|
url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
|
||||||
config_id:
|
config_id: self-check
|
||||||
agents:
|
agents:
|
||||||
- provider_id: meta-reference
|
- provider_id: meta-reference
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
|
@ -139,6 +139,7 @@ models:
|
||||||
model_type: llm
|
model_type: llm
|
||||||
shields:
|
shields:
|
||||||
- shield_id: ${env.SAFETY_MODEL}
|
- shield_id: ${env.SAFETY_MODEL}
|
||||||
|
provider_id: nvidia
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
datasets: []
|
datasets: []
|
||||||
scoring_fns: []
|
scoring_fns: []
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue