From f7e5ae5dfcbb9c6a6e415dc24f50f8db3bcb70b8 Mon Sep 17 00:00:00 2001 From: Chantal D Gama Rose Date: Tue, 25 Feb 2025 19:49:18 +0000 Subject: [PATCH] Adding docstrings --- .../providers/remote/safety/nvidia/config.py | 18 ++---- .../providers/remote/safety/nvidia/nvidia.py | 57 ++++++++++++++++++- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/llama_stack/providers/remote/safety/nvidia/config.py b/llama_stack/providers/remote/safety/nvidia/config.py index 90930bf00..3df80ed4f 100644 --- a/llama_stack/providers/remote/safety/nvidia/config.py +++ b/llama_stack/providers/remote/safety/nvidia/config.py @@ -14,25 +14,17 @@ from llama_stack.schema_utils import json_schema_type @json_schema_type class NVIDIASafetyConfig(BaseModel): """ - Configuration for the NVIDIA Guardrail microservice endpoint. + Configuration for the NVIDIA Guardrail microservice endpoint. Attributes: - guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000 - api_key (str): The access key for the hosted NIM endpoints + guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331 + config_id (str): The ID of the guardrails configuration to use from the configuration store + (https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html) - There are two ways to access NVIDIA NIMs - - 0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com - 1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure - - By default the configuration is set to use the hosted APIs. This requires - an API key which can be obtained from https://ngc.nvidia.com/. - - By default the configuration will attempt to read the NVIDIA_API_KEY environment - variable to set the api_key. Please do not put your API key in code. """ guardrails_service_url: str = Field( - default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"), + default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"), description="The url for accessing the guardrails service", ) config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store") diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 1b7295df2..40c6c2dfd 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -22,6 +22,12 @@ logger = logging.getLogger(__name__) class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: NVIDIASafetyConfig) -> None: + """ + Initialize the NVIDIASafetyAdapter with a given safety configuration. + + Args: + config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID. + """ print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...") self.config = config @@ -38,6 +44,20 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): async def run_shield( self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: + """ + Run a safety shield check against the provided messages. + + Args: + shield_id (str): The unique identifier for the shield to be used. + messages (List[Message]): A list of Message objects representing the conversation history. + params (Dict[str, Any], optional): Additional parameters for the safety check. + + Returns: + RunShieldResponse: The response containing safety violation details if any. + + Raises: + ValueError: If the shield with the provided shield_id is not found. + """ shield = await self.shield_store.get_shield(shield_id) if not shield: raise ValueError(f"Shield {shield_id} not found") @@ -47,6 +67,13 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): class NeMoGuardrails: + """ + A class that encapsulates NVIDIA's guardrails safety logic. + + Sends messages to the guardrails service and interprets the response to determine + if a safety violation has occurred. + """ + def __init__( self, config: NVIDIASafetyConfig, @@ -54,11 +81,22 @@ class NeMoGuardrails: threshold: float = 0.9, temperature: float = 1.0, ): + """ + Initialize a NeMoGuardrails instance with the provided parameters. + + Args: + config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL. + model (str): The identifier or name of the model to be used for safety checks. + threshold (float, optional): The threshold for flagging violations. Defaults to 0.9. + temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0. + + Raises: + ValueError: If temperature is less than or equal to 0. + AssertionError: If config_id is not provided in the configuration. + """ self.config_id = config.config_id self.model = model - assert self.config_id is not None or self.config_store_path is not None, ( - "Must provide one of config id or config store path" - ) + assert self.config_id is not None("Must provide config id") if temperature <= 0: raise ValueError("Temperature must be greater than 0") @@ -67,6 +105,19 @@ class NeMoGuardrails: self.guardrails_service_url = config.guardrails_service_url async def run(self, messages: List[Message]) -> RunShieldResponse: + """ + Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API. + + Args: + messages (List[Message]): A list of Message objects to be checked for safety violations. + + Returns: + RunShieldResponse: If the response indicates a violation ("blocked" status), returns a + RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None. + + Raises: + requests.HTTPError: If the POST request fails. + """ headers = { "Accept": "application/json", }