mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 19:12:09 +00:00
Adding docstrings
This commit is contained in:
parent
8d095aabe6
commit
f7e5ae5dfc
2 changed files with 59 additions and 16 deletions
|
@ -17,22 +17,14 @@ class NVIDIASafetyConfig(BaseModel):
|
||||||
Configuration for the NVIDIA Guardrail microservice endpoint.
|
Configuration for the NVIDIA Guardrail microservice endpoint.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000
|
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
|
||||||
api_key (str): The access key for the hosted NIM endpoints
|
config_id (str): The ID of the guardrails configuration to use from the configuration store
|
||||||
|
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
|
||||||
|
|
||||||
There are two ways to access NVIDIA NIMs -
|
|
||||||
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
|
|
||||||
1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure
|
|
||||||
|
|
||||||
By default the configuration is set to use the hosted APIs. This requires
|
|
||||||
an API key which can be obtained from https://ngc.nvidia.com/.
|
|
||||||
|
|
||||||
By default the configuration will attempt to read the NVIDIA_API_KEY environment
|
|
||||||
variable to set the api_key. Please do not put your API key in code.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
guardrails_service_url: str = Field(
|
guardrails_service_url: str = Field(
|
||||||
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"),
|
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
|
||||||
description="The url for accessing the guardrails service",
|
description="The url for accessing the guardrails service",
|
||||||
)
|
)
|
||||||
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
|
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
|
||||||
|
|
|
@ -22,6 +22,12 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the NVIDIASafetyAdapter with a given safety configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
|
||||||
|
"""
|
||||||
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
|
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
@ -38,6 +44,20 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
|
self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
"""
|
||||||
|
Run a safety shield check against the provided messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shield_id (str): The unique identifier for the shield to be used.
|
||||||
|
messages (List[Message]): A list of Message objects representing the conversation history.
|
||||||
|
params (Dict[str, Any], optional): Additional parameters for the safety check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RunShieldResponse: The response containing safety violation details if any.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the shield with the provided shield_id is not found.
|
||||||
|
"""
|
||||||
shield = await self.shield_store.get_shield(shield_id)
|
shield = await self.shield_store.get_shield(shield_id)
|
||||||
if not shield:
|
if not shield:
|
||||||
raise ValueError(f"Shield {shield_id} not found")
|
raise ValueError(f"Shield {shield_id} not found")
|
||||||
|
@ -47,6 +67,13 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
|
|
||||||
|
|
||||||
class NeMoGuardrails:
|
class NeMoGuardrails:
|
||||||
|
"""
|
||||||
|
A class that encapsulates NVIDIA's guardrails safety logic.
|
||||||
|
|
||||||
|
Sends messages to the guardrails service and interprets the response to determine
|
||||||
|
if a safety violation has occurred.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: NVIDIASafetyConfig,
|
config: NVIDIASafetyConfig,
|
||||||
|
@ -54,11 +81,22 @@ class NeMoGuardrails:
|
||||||
threshold: float = 0.9,
|
threshold: float = 0.9,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize a NeMoGuardrails instance with the provided parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
|
||||||
|
model (str): The identifier or name of the model to be used for safety checks.
|
||||||
|
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
|
||||||
|
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If temperature is less than or equal to 0.
|
||||||
|
AssertionError: If config_id is not provided in the configuration.
|
||||||
|
"""
|
||||||
self.config_id = config.config_id
|
self.config_id = config.config_id
|
||||||
self.model = model
|
self.model = model
|
||||||
assert self.config_id is not None or self.config_store_path is not None, (
|
assert self.config_id is not None("Must provide config id")
|
||||||
"Must provide one of config id or config store path"
|
|
||||||
)
|
|
||||||
if temperature <= 0:
|
if temperature <= 0:
|
||||||
raise ValueError("Temperature must be greater than 0")
|
raise ValueError("Temperature must be greater than 0")
|
||||||
|
|
||||||
|
@ -67,6 +105,19 @@ class NeMoGuardrails:
|
||||||
self.guardrails_service_url = config.guardrails_service_url
|
self.guardrails_service_url = config.guardrails_service_url
|
||||||
|
|
||||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||||
|
"""
|
||||||
|
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (List[Message]): A list of Message objects to be checked for safety violations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
|
||||||
|
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
requests.HTTPError: If the POST request fails.
|
||||||
|
"""
|
||||||
headers = {
|
headers = {
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue