Adding docstrings

This commit is contained in:
Chantal D Gama Rose 2025-02-25 19:49:18 +00:00
parent 8d095aabe6
commit f7e5ae5dfc
2 changed files with 59 additions and 16 deletions

View file

@ -17,22 +17,14 @@ class NVIDIASafetyConfig(BaseModel):
Configuration for the NVIDIA Guardrail microservice endpoint.
Attributes:
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://localhost:8000
api_key (str): The access key for the hosted NIM endpoints
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
config_id (str): The ID of the guardrails configuration to use from the configuration store
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
There are two ways to access NVIDIA NIMs -
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure
By default the configuration is set to use the hosted APIs. This requires
an API key which can be obtained from https://ngc.nvidia.com/.
By default the configuration will attempt to read the NVIDIA_API_KEY environment
variable to set the api_key. Please do not put your API key in code.
"""
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "http://0.0.0.0:7331"),
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the guardrails service",
)
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")

View file

@ -22,6 +22,12 @@ logger = logging.getLogger(__name__)
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:
"""
Initialize the NVIDIASafetyAdapter with a given safety configuration.
Args:
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
"""
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
self.config = config
@ -38,6 +44,20 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
async def run_shield(
self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
Args:
shield_id (str): The unique identifier for the shield to be used.
messages (List[Message]): A list of Message objects representing the conversation history.
params (Dict[str, Any], optional): Additional parameters for the safety check.
Returns:
RunShieldResponse: The response containing safety violation details if any.
Raises:
ValueError: If the shield with the provided shield_id is not found.
"""
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
@ -47,6 +67,13 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
class NeMoGuardrails:
"""
A class that encapsulates NVIDIA's guardrails safety logic.
Sends messages to the guardrails service and interprets the response to determine
if a safety violation has occurred.
"""
def __init__(
self,
config: NVIDIASafetyConfig,
@ -54,11 +81,22 @@ class NeMoGuardrails:
threshold: float = 0.9,
temperature: float = 1.0,
):
"""
Initialize a NeMoGuardrails instance with the provided parameters.
Args:
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
model (str): The identifier or name of the model to be used for safety checks.
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
Raises:
ValueError: If temperature is less than or equal to 0.
AssertionError: If config_id is not provided in the configuration.
"""
self.config_id = config.config_id
self.model = model
assert self.config_id is not None or self.config_store_path is not None, (
"Must provide one of config id or config store path"
)
assert self.config_id is not None("Must provide config id")
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
@ -67,6 +105,19 @@ class NeMoGuardrails:
self.guardrails_service_url = config.guardrails_service_url
async def run(self, messages: List[Message]) -> RunShieldResponse:
"""
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
Args:
messages (List[Message]): A list of Message objects to be checked for safety violations.
Returns:
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
Raises:
requests.HTTPError: If the POST request fails.
"""
headers = {
"Accept": "application/json",
}