Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-27 06:28:50 +00:00)
chore(test): migrate unit tests from unittest to pytest nvidia test safety (#2793)
This PR replaces unittest with pytest. Part of https://github.com/meta-llama/llama-stack/issues/2680. cc @leseb

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Parent: 9069d878ef
Commit: 6ab5760a1b
1 changed file with 293 additions and 261 deletions
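For orientation before the diff, here is a minimal, self-contained sketch of the unittest-to-pytest pattern this change applies: a fixture replaces setUp()/tearDown(), pytest.raises replaces assertRaises, and coroutines are awaited directly instead of going through a run_async helper. The Service and client names below are placeholders for illustration, not llama-stack APIs, and the async tests assume pytest-asyncio (or anyio) is configured to collect plain async def tests.

# Minimal, generic sketch of the unittest -> pytest migration pattern (hypothetical names).
from unittest.mock import AsyncMock

import pytest


class Service:
    """Placeholder service with an async dependency, standing in for the adapter under test."""

    def __init__(self, client):
        self.client = client

    async def check(self, text: str) -> str:
        if not text:
            raise ValueError("empty input")
        return await self.client.post(text)


# Before: a unittest.TestCase built this in setUp() and cleaned up in tearDown().
# After: a fixture owns the setup; pytest injects it into each test that requests it.
@pytest.fixture
def service():
    client = AsyncMock()
    client.post.return_value = "allowed"
    return Service(client)


# self.run_async(...) wrappers become plain async tests
# (assuming pytest-asyncio or anyio collects them).
async def test_check_allowed(service):
    assert await service.check("hello") == "allowed"
    service.client.post.assert_called_once_with("hello")


# self.assertRaises(ValueError) becomes pytest.raises(ValueError).
async def test_check_empty_input(service):
    with pytest.raises(ValueError):
        await service.check("")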
@@ -5,321 +5,353 @@

Removed (unittest version, summarized): the file previously defined class TestNVIDIASafetyAdapter(unittest.TestCase). Its setUp() exported NVIDIA_GUARDRAILS_URL, built an NVIDIASafetyAdapter with an AsyncMock shield_store, and started a patch on NeMoGuardrails._guardrails_post that tearDown() stopped; an autouse inject_fixtures(self, run_async) fixture supplied self.run_async, the async adapter calls were driven through self.run_async(...), errors were asserted with self.assertRaises(...), and Shield was built with type="shield" and stop_reason="end_of_message" as plain strings. The test bodies and assertions were otherwise the same as in the pytest version below.

Added (pytest version of the test module):

# the root directory of this source tree.

import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from llama_stack.apis.inference import CompletionMessage, UserMessage
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter


class TestNVIDIASafetyAdapter(NVIDIASafetyAdapter):
    """Test implementation that provides the required shield_store."""

    def __init__(self, config: NVIDIASafetyConfig, shield_store):
        super().__init__(config)
        self.shield_store = shield_store


@pytest.fixture
def nvidia_adapter():
    """Set up the NVIDIASafetyAdapter for testing."""
    os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"

    # Initialize the adapter
    config = NVIDIASafetyConfig(
        guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
    )

    # Create a mock shield store that implements the ShieldStore protocol
    shield_store = AsyncMock()
    shield_store.get_shield = AsyncMock()

    adapter = TestNVIDIASafetyAdapter(config=config, shield_store=shield_store)

    return adapter


@pytest.fixture
def mock_guardrails_post():
    """Mock the HTTP request methods."""
    with patch("llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post") as mock_post:
        mock_post.return_value = {"status": "allowed"}
        yield mock_post


def _assert_request(
    mock_call: MagicMock,
    expected_url: str,
    expected_headers: dict[str, str] | None = None,
    expected_json: dict[str, Any] | None = None,
) -> None:
    """
    Helper method to verify request details in mock API calls.

    Args:
        mock_call: The MagicMock object that was called
        expected_url: The expected URL to which the request was made
        expected_headers: Optional dictionary of expected request headers
        expected_json: Optional dictionary of expected JSON payload
    """
    call_args = mock_call.call_args

    # Check URL
    assert call_args[0][0] == expected_url

    # Check headers if provided
    if expected_headers:
        for key, value in expected_headers.items():
            assert call_args[1]["headers"][key] == value

    # Check JSON if provided
    if expected_json:
        for key, value in expected_json.items():
            if isinstance(value, dict):
                for nested_key, nested_value in value.items():
                    assert call_args[1]["json"][key][nested_key] == nested_value
            else:
                assert call_args[1]["json"][key] == value


async def test_register_shield_with_valid_id(nvidia_adapter):
    adapter = nvidia_adapter

    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier="test-shield",
        provider_resource_id="test-model",
    )

    # Register the shield
    await adapter.register_shield(shield)


async def test_register_shield_without_id(nvidia_adapter):
    adapter = nvidia_adapter

    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier="test-shield",
        provider_resource_id="",
    )

    # Register the shield should raise a ValueError
    with pytest.raises(ValueError):
        await adapter.register_shield(shield)


async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
    adapter = nvidia_adapter

    # Set up the shield
    shield_id = "test-shield"
    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier=shield_id,
        provider_resource_id="test-model",
    )
    adapter.shield_store.get_shield.return_value = shield

    # Mock Guardrails API response
    mock_guardrails_post.return_value = {"status": "allowed"}

    # Run the shield
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
            content="I'm doing well, thank you for asking!",
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
    result = await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    adapter.shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was called correctly
    mock_guardrails_post.assert_called_once_with(
        path="/v1/guardrail/checks",
        data={
            "model": shield_id,
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
            ],
            "temperature": 1.0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": 160,
            "stream": False,
            "guardrails": {
                "config_id": "self-check",
            },
        },
    )

    # Verify the result
    assert isinstance(result, RunShieldResponse)
    assert result.violation is None


async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
    adapter = nvidia_adapter

    # Set up the shield
    shield_id = "test-shield"
    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier=shield_id,
        provider_resource_id="test-model",
    )
    adapter.shield_store.get_shield.return_value = shield

    # Mock Guardrails API response
    mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}

    # Run the shield
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
            content="I'm doing well, thank you for asking!",
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
    result = await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    adapter.shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was called correctly
    mock_guardrails_post.assert_called_once_with(
        path="/v1/guardrail/checks",
        data={
            "model": shield_id,
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
            ],
            "temperature": 1.0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": 160,
            "stream": False,
            "guardrails": {
                "config_id": "self-check",
            },
        },
    )

    # Verify the result
    assert result.violation is not None
    assert isinstance(result, RunShieldResponse)
    assert result.violation.user_message == "Sorry I cannot do this."
    assert result.violation.violation_level == ViolationLevel.ERROR
    assert result.violation.metadata == {"reason": "harmful_content"}


async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
    adapter = nvidia_adapter

    # Set up shield store to return None
    shield_id = "non-existent-shield"
    adapter.shield_store.get_shield.return_value = None

    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
    ]

    with pytest.raises(ValueError):
        await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    adapter.shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was not called
    mock_guardrails_post.assert_not_called()


async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
    adapter = nvidia_adapter

    shield_id = "test-shield"
    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier=shield_id,
        provider_resource_id="test-model",
    )
    adapter.shield_store.get_shield.return_value = shield

    # Mock Guardrails API to raise an exception
    error_msg = "API Error: 500 Internal Server Error"
    mock_guardrails_post.side_effect = Exception(error_msg)

    # Running the shield should raise an exception
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
            content="I'm doing well, thank you for asking!",
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
    with pytest.raises(Exception) as exc_info:
        await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    adapter.shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was called correctly
    mock_guardrails_post.assert_called_once_with(
        path="/v1/guardrail/checks",
        data={
            "model": shield_id,
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
            ],
            "temperature": 1.0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": 160,
            "stream": False,
            "guardrails": {
                "config_id": "self-check",
            },
        },
    )
    # Verify the exception message
    assert error_msg in str(exc_info.value)


def test_init_nemo_guardrails():
    from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails

    os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"

    test_config_id = "test-custom-config-id"
    config = NVIDIASafetyConfig(
        guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
        config_id=test_config_id,
    )
    # Initialize with default parameters
    test_model = "test-model"
    guardrails = NeMoGuardrails(config, test_model)

    # Verify the attributes are set correctly
    assert guardrails.config_id == test_config_id
    assert guardrails.model == test_model
    assert guardrails.threshold == 0.9  # Default value
    assert guardrails.temperature == 1.0  # Default value
    assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]

    # Initialize with custom parameters
    guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)

    # Verify the attributes are set correctly
    assert guardrails.config_id == test_config_id
    assert guardrails.model == test_model
    assert guardrails.threshold == 0.8
    assert guardrails.temperature == 0.7
    assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]


def test_init_nemo_guardrails_invalid_temperature():
    from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails

    os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"

    config = NVIDIASafetyConfig(
        guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
        config_id="test-custom-config-id",
    )
    with pytest.raises(ValueError):
        NeMoGuardrails(config, "test-model", temperature=0)
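Note that the migrated tests are plain async def functions with no explicit @pytest.mark.asyncio marker, so they rely on whatever async pytest setup the repository already provides (for example pytest-asyncio in auto mode, the anyio plugin, or a shared conftest). The sketch below is a hypothetical conftest.py, not part of this PR, showing one way such tests could be collected in isolation using the anyio plugin.

# Hypothetical conftest.py sketch (assumption: anyio plugin installed; not part of this PR).
import inspect

import pytest


@pytest.fixture
def anyio_backend():
    # Limit the anyio plugin to the asyncio backend so each async test runs once.
    return "asyncio"


def pytest_collection_modifyitems(items):
    # Apply the anyio marker only to coroutine tests collected in this directory.
    for item in items:
        if inspect.iscoroutinefunction(getattr(item, "function", None)):
            item.add_marker(pytest.mark.anyio)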