mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
Merge branch 'main' into fix-chroma
This commit is contained in:
commit
062c6a419a
76 changed files with 2468 additions and 913 deletions
|
@ -15,6 +15,7 @@ from llama_stack.apis.models import Model, ModelType
|
|||
from llama_stack.apis.shields.shields import Shield
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.distribution.datatypes import RegistryEntrySource
|
||||
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
||||
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
||||
|
@ -45,6 +46,30 @@ class InferenceImpl(Impl):
|
|||
async def unregister_model(self, model_id: str):
|
||||
return model_id
|
||||
|
||||
async def should_refresh_models(self):
|
||||
return False
|
||||
|
||||
async def list_models(self):
|
||||
return [
|
||||
Model(
|
||||
identifier="provider-model-1",
|
||||
provider_resource_id="provider-model-1",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
Model(
|
||||
identifier="provider-model-2",
|
||||
provider_resource_id="provider-model-2",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 512},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
|
||||
class SafetyImpl(Impl):
|
||||
def __init__(self):
|
||||
|
@ -378,3 +403,170 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
|
|||
raise AssertionError("Should have raised ValueError for non-existent model")
|
||||
except ValueError as e:
|
||||
assert "not found" in str(e)
|
||||
|
||||
|
||||
async def test_models_source_tracking_default(cached_disk_dist_registry):
|
||||
"""Test that models registered via register_model get default source."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register model via register_model (should get default source)
|
||||
await table.register_model(model_id="user-model", provider_id="test_provider")
|
||||
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
model = models.data[0]
|
||||
assert model.source == RegistryEntrySource.via_register_api
|
||||
assert model.identifier == "test_provider/user-model"
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_source_tracking_provider(cached_disk_dist_registry):
|
||||
"""Test that models registered via update_registered_models get provider source."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Simulate provider refresh by calling update_registered_models
|
||||
provider_models = [
|
||||
Model(
|
||||
identifier="provider-model-1",
|
||||
provider_resource_id="provider-model-1",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
Model(
|
||||
identifier="provider-model-2",
|
||||
provider_resource_id="provider-model-2",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 512},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models)
|
||||
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
# All models should have provider source
|
||||
for model in models.data:
|
||||
assert model.source == RegistryEntrySource.listed_from_provider
|
||||
assert model.provider_id == "test_provider"
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
|
||||
"""Test that provider refresh preserves user-registered models with default source."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# First register a user model with same provider_resource_id as provider will later provide
|
||||
await table.register_model(
|
||||
model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
|
||||
)
|
||||
|
||||
# Verify user model is registered with default source
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
user_model = models.data[0]
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert user_model.identifier == "my-custom-alias"
|
||||
assert user_model.provider_resource_id == "provider-model-1"
|
||||
|
||||
# Now simulate provider refresh
|
||||
provider_models = [
|
||||
Model(
|
||||
identifier="provider-model-1",
|
||||
provider_resource_id="provider-model-1",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
Model(
|
||||
identifier="different-model",
|
||||
provider_resource_id="different-model",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models)
|
||||
|
||||
# Verify user model with alias is preserved, but provider added new model
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
# Find the user model and provider model
|
||||
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "different-model"), None)
|
||||
|
||||
assert user_model is not None
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert user_model.provider_resource_id == "provider-model-1"
|
||||
|
||||
assert provider_model is not None
|
||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
||||
assert provider_model.provider_resource_id == "different-model"
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
|
||||
"""Test that provider refresh removes old provider models but keeps default ones."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a user model
|
||||
await table.register_model(model_id="user-model", provider_id="test_provider")
|
||||
|
||||
# Add some provider models
|
||||
provider_models_v1 = [
|
||||
Model(
|
||||
identifier="provider-model-old",
|
||||
provider_resource_id="provider-model-old",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models_v1)
|
||||
|
||||
# Verify we have both user and provider models
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
# Now update with new provider models (should remove old provider models)
|
||||
provider_models_v2 = [
|
||||
Model(
|
||||
identifier="provider-model-new",
|
||||
provider_resource_id="provider-model-new",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models_v2)
|
||||
|
||||
# Should have user model + new provider model, old provider model gone
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
identifiers = {m.identifier for m in models.data}
|
||||
assert "test_provider/user-model" in identifiers # User model preserved
|
||||
assert "provider-model-new" in identifiers # New provider model (uses provider's identifier)
|
||||
assert "provider-model-old" not in identifiers # Old provider model removed
|
||||
|
||||
# Verify sources are correct
|
||||
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None)
|
||||
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
|
|
@ -5,321 +5,353 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import CompletionMessage, UserMessage
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
|
||||
|
||||
|
||||
class TestNVIDIASafetyAdapter(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
class TestNVIDIASafetyAdapter(NVIDIASafetyAdapter):
|
||||
"""Test implementation that provides the required shield_store."""
|
||||
|
||||
# Initialize the adapter
|
||||
self.config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
)
|
||||
self.adapter = NVIDIASafetyAdapter(config=self.config)
|
||||
self.shield_store = AsyncMock()
|
||||
self.adapter.shield_store = self.shield_store
|
||||
def __init__(self, config: NVIDIASafetyConfig, shield_store):
|
||||
super().__init__(config)
|
||||
self.shield_store = shield_store
|
||||
|
||||
# Mock the HTTP request methods
|
||||
self.guardrails_post_patcher = patch(
|
||||
"llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
|
||||
)
|
||||
self.mock_guardrails_post = self.guardrails_post_patcher.start()
|
||||
self.mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after each test."""
|
||||
self.guardrails_post_patcher.stop()
|
||||
@pytest.fixture
|
||||
def nvidia_adapter():
|
||||
"""Set up the NVIDIASafetyAdapter for testing."""
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
# Initialize the adapter
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
)
|
||||
|
||||
def _assert_request(
|
||||
self,
|
||||
mock_call: MagicMock,
|
||||
expected_url: str,
|
||||
expected_headers: dict[str, str] | None = None,
|
||||
expected_json: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to verify request details in mock API calls.
|
||||
# Create a mock shield store that implements the ShieldStore protocol
|
||||
shield_store = AsyncMock()
|
||||
shield_store.get_shield = AsyncMock()
|
||||
|
||||
Args:
|
||||
mock_call: The MagicMock object that was called
|
||||
expected_url: The expected URL to which the request was made
|
||||
expected_headers: Optional dictionary of expected request headers
|
||||
expected_json: Optional dictionary of expected JSON payload
|
||||
"""
|
||||
call_args = mock_call.call_args
|
||||
adapter = TestNVIDIASafetyAdapter(config=config, shield_store=shield_store)
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == expected_url
|
||||
return adapter
|
||||
|
||||
# Check headers if provided
|
||||
if expected_headers:
|
||||
for key, value in expected_headers.items():
|
||||
assert call_args[1]["headers"][key] == value
|
||||
|
||||
# Check JSON if provided
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert call_args[1]["json"][key][nested_key] == nested_value
|
||||
else:
|
||||
assert call_args[1]["json"][key] == value
|
||||
@pytest.fixture
|
||||
def mock_guardrails_post():
|
||||
"""Mock the HTTP request methods."""
|
||||
with patch("llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post") as mock_post:
|
||||
mock_post.return_value = {"status": "allowed"}
|
||||
yield mock_post
|
||||
|
||||
def test_register_shield_with_valid_id(self):
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
|
||||
# Register the shield
|
||||
self.run_async(self.adapter.register_shield(shield))
|
||||
def _assert_request(
|
||||
mock_call: MagicMock,
|
||||
expected_url: str,
|
||||
expected_headers: dict[str, str] | None = None,
|
||||
expected_json: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to verify request details in mock API calls.
|
||||
|
||||
def test_register_shield_without_id(self):
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="",
|
||||
)
|
||||
Args:
|
||||
mock_call: The MagicMock object that was called
|
||||
expected_url: The expected URL to which the request was made
|
||||
expected_headers: Optional dictionary of expected request headers
|
||||
expected_json: Optional dictionary of expected JSON payload
|
||||
"""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
# Register the shield should raise a ValueError
|
||||
with self.assertRaises(ValueError):
|
||||
self.run_async(self.adapter.register_shield(shield))
|
||||
# Check URL
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
def test_run_shield_allowed(self):
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
# Check headers if provided
|
||||
if expected_headers:
|
||||
for key, value in expected_headers.items():
|
||||
assert call_args[1]["headers"][key] == value
|
||||
|
||||
# Mock Guardrails API response
|
||||
self.mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
# Check JSON if provided
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert call_args[1]["json"][key][nested_key] == nested_value
|
||||
else:
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
async def test_register_shield_with_valid_id(nvidia_adapter):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier="test-shield",
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
|
||||
# Register the shield
|
||||
await adapter.register_shield(shield)
|
||||
|
||||
|
||||
async def test_register_shield_without_id(nvidia_adapter):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier="test-shield",
|
||||
provider_resource_id="",
|
||||
)
|
||||
|
||||
# Register the shield should raise a ValueError
|
||||
with pytest.raises(ValueError):
|
||||
await adapter.register_shield(shield)
|
||||
|
||||
|
||||
async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
adapter.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API response
|
||||
mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = await adapter.run_shield(shield_id, messages)
|
||||
|
||||
# Verify the shield store was called
|
||||
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation is None
|
||||
# Verify the result
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation is None
|
||||
|
||||
def test_run_shield_blocked(self):
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API response
|
||||
self.mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
|
||||
async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
adapter.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
# Mock Guardrails API response
|
||||
mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = await adapter.run_shield(shield_id, messages)
|
||||
|
||||
# Verify the shield store was called
|
||||
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result.violation is not None
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation.user_message == "Sorry I cannot do this."
|
||||
assert result.violation.violation_level == ViolationLevel.ERROR
|
||||
assert result.violation.metadata == {"reason": "harmful_content"}
|
||||
# Verify the result
|
||||
assert result.violation is not None
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation.user_message == "Sorry I cannot do this."
|
||||
assert result.violation.violation_level == ViolationLevel.ERROR
|
||||
assert result.violation.metadata == {"reason": "harmful_content"}
|
||||
|
||||
def test_run_shield_not_found(self):
|
||||
# Set up shield store to return None
|
||||
shield_id = "non-existent-shield"
|
||||
self.shield_store.get_shield.return_value = None
|
||||
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
]
|
||||
async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
# Set up shield store to return None
|
||||
shield_id = "non-existent-shield"
|
||||
adapter.shield_store.get_shield.return_value = None
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
]
|
||||
|
||||
# Verify the Guardrails API was not called
|
||||
self.mock_guardrails_post.assert_not_called()
|
||||
with pytest.raises(ValueError):
|
||||
await adapter.run_shield(shield_id, messages)
|
||||
|
||||
def test_run_shield_http_error(self):
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
# Verify the shield store was called
|
||||
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Mock Guardrails API to raise an exception
|
||||
error_msg = "API Error: 500 Internal Server Error"
|
||||
self.mock_guardrails_post.side_effect = Exception(error_msg)
|
||||
# Verify the Guardrails API was not called
|
||||
mock_guardrails_post.assert_not_called()
|
||||
|
||||
# Running the shield should raise an exception
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
with self.assertRaises(Exception) as context:
|
||||
self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
adapter.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API to raise an exception
|
||||
error_msg = "API Error: 500 Internal Server Error"
|
||||
mock_guardrails_post.side_effect = Exception(error_msg)
|
||||
|
||||
# Running the shield should raise an exception
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await adapter.run_shield(shield_id, messages)
|
||||
|
||||
# Verify the shield store was called
|
||||
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
)
|
||||
# Verify the exception message
|
||||
assert error_msg in str(context.exception)
|
||||
},
|
||||
)
|
||||
# Verify the exception message
|
||||
assert error_msg in str(exc_info.value)
|
||||
|
||||
def test_init_nemo_guardrails(self):
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
|
||||
test_config_id = "test-custom-config-id"
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id=test_config_id,
|
||||
)
|
||||
# Initialize with default parameters
|
||||
test_model = "test-model"
|
||||
guardrails = NeMoGuardrails(config, test_model)
|
||||
def test_init_nemo_guardrails():
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.9 # Default value
|
||||
assert guardrails.temperature == 1.0 # Default value
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
|
||||
# Initialize with custom parameters
|
||||
guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)
|
||||
test_config_id = "test-custom-config-id"
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id=test_config_id,
|
||||
)
|
||||
# Initialize with default parameters
|
||||
test_model = "test-model"
|
||||
guardrails = NeMoGuardrails(config, test_model)
|
||||
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.8
|
||||
assert guardrails.temperature == 0.7
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.9 # Default value
|
||||
assert guardrails.temperature == 1.0 # Default value
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
|
||||
def test_init_nemo_guardrails_invalid_temperature(self):
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
# Initialize with custom parameters
|
||||
guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)
|
||||
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id="test-custom-config-id",
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
NeMoGuardrails(config, "test-model", temperature=0)
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.8
|
||||
assert guardrails.temperature == 0.7
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
|
||||
|
||||
def test_init_nemo_guardrails_invalid_temperature():
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id="test-custom-config-id",
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
NeMoGuardrails(config, "test-model", temperature=0)
|
||||
|
|
|
@ -19,7 +19,8 @@ from llama_stack.distribution.datatypes import (
|
|||
OAuth2JWKSConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
)
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
||||
from llama_stack.distribution.request_headers import User
|
||||
from llama_stack.distribution.server.auth import AuthenticationMiddleware, _has_required_scope
|
||||
from llama_stack.distribution.server.auth_providers import (
|
||||
get_attributes_from_claims,
|
||||
)
|
||||
|
@ -73,7 +74,7 @@ def http_app(mock_auth_endpoint):
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
|
@ -111,7 +112,50 @@ def mock_http_middleware(mock_auth_endpoint):
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
return AuthenticationMiddleware(mock_app, auth_config), mock_app
|
||||
return AuthenticationMiddleware(mock_app, auth_config, {}), mock_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_impls():
|
||||
"""Mock implementations for scope testing"""
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scope_middleware_with_mocks(mock_auth_endpoint):
|
||||
"""Create AuthenticationMiddleware with mocked route implementations"""
|
||||
mock_app = AsyncMock()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_config=CustomAuthConfig(
|
||||
type=AuthProviderType.CUSTOM,
|
||||
endpoint=mock_auth_endpoint,
|
||||
),
|
||||
access_policy=[],
|
||||
)
|
||||
middleware = AuthenticationMiddleware(mock_app, auth_config, {})
|
||||
|
||||
# Mock the route_impls to simulate finding routes with required scopes
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read")
|
||||
|
||||
public_webmethod = WebMethod(route="/test/public", method="GET")
|
||||
|
||||
# Mock the route finding logic
|
||||
def mock_find_matching_route(method, path, route_impls):
|
||||
if method == "POST" and path == "/test/scoped":
|
||||
return None, {}, "/test/scoped", scoped_webmethod
|
||||
elif method == "GET" and path == "/test/public":
|
||||
return None, {}, "/test/public", public_webmethod
|
||||
else:
|
||||
raise ValueError("No matching route")
|
||||
|
||||
import llama_stack.distribution.server.auth
|
||||
|
||||
llama_stack.distribution.server.auth.find_matching_route = mock_find_matching_route
|
||||
llama_stack.distribution.server.auth.initialize_route_impls = lambda impls: {}
|
||||
|
||||
return middleware, mock_app
|
||||
|
||||
|
||||
async def mock_post_success(*args, **kwargs):
|
||||
|
@ -138,6 +182,36 @@ async def mock_post_exception(*args, **kwargs):
|
|||
raise Exception("Connection error")
|
||||
|
||||
|
||||
async def mock_post_success_with_scope(*args, **kwargs):
|
||||
"""Mock auth response for user with test.read scope"""
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"message": "Authentication successful",
|
||||
"principal": "test-user",
|
||||
"attributes": {
|
||||
"scopes": ["test.read", "other.scope"],
|
||||
"roles": ["user"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def mock_post_success_no_scope(*args, **kwargs):
|
||||
"""Mock auth response for user without required scope"""
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"message": "Authentication successful",
|
||||
"principal": "test-user",
|
||||
"attributes": {
|
||||
"scopes": ["other.scope"],
|
||||
"roles": ["user"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# HTTP Endpoint Tests
|
||||
def test_missing_auth_header(http_client):
|
||||
response = http_client.get("/test")
|
||||
|
@ -252,7 +326,7 @@ def oauth2_app():
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
|
@ -351,7 +425,7 @@ def oauth2_app_with_jwks_token():
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
|
@ -442,7 +516,7 @@ def introspection_app(mock_introspection_endpoint):
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
|
@ -472,7 +546,7 @@ def introspection_app_with_custom_mapping(mock_introspection_endpoint):
|
|||
),
|
||||
access_policy=[],
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
|
@ -581,3 +655,122 @@ def test_valid_introspection_with_custom_mapping_authentication(
|
|||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
# Scope-based authorization tests
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope)
|
||||
async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key):
|
||||
"""Test that user with required scope can access protected endpoint"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/test/scoped",
|
||||
"method": "POST",
|
||||
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
|
||||
}
|
||||
|
||||
await middleware(scope, mock_receive, mock_send)
|
||||
|
||||
# Should call the downstream app (no 403 error sent)
|
||||
mock_app.assert_called_once_with(scope, mock_receive, mock_send)
|
||||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||
async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key):
|
||||
"""Test that user without required scope gets 403 access denied"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/test/scoped",
|
||||
"method": "POST",
|
||||
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
|
||||
}
|
||||
|
||||
await middleware(scope, mock_receive, mock_send)
|
||||
|
||||
# Should send 403 error, not call downstream app
|
||||
mock_app.assert_not_called()
|
||||
assert mock_send.call_count == 2 # start + body
|
||||
|
||||
# Check the response
|
||||
start_call = mock_send.call_args_list[0][0][0]
|
||||
assert start_call["status"] == 403
|
||||
|
||||
body_call = mock_send.call_args_list[1][0][0]
|
||||
body_text = body_call["body"].decode()
|
||||
assert "Access denied" in body_text
|
||||
assert "test.read" in body_text
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
|
||||
async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key):
|
||||
"""Test that public endpoints work without specific scopes"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/test/public",
|
||||
"method": "GET",
|
||||
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
|
||||
}
|
||||
|
||||
await middleware(scope, mock_receive, mock_send)
|
||||
|
||||
# Should call the downstream app (no error)
|
||||
mock_app.assert_called_once_with(scope, mock_receive, mock_send)
|
||||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks):
|
||||
"""Test that when auth is disabled (no user), scope checks are bypassed"""
|
||||
middleware, mock_app = scope_middleware_with_mocks
|
||||
mock_receive = AsyncMock()
|
||||
mock_send = AsyncMock()
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"path": "/test/scoped",
|
||||
"method": "POST",
|
||||
"headers": [], # No authorization header
|
||||
}
|
||||
|
||||
await middleware(scope, mock_receive, mock_send)
|
||||
|
||||
# Should send 401 auth error, not call downstream app
|
||||
mock_app.assert_not_called()
|
||||
assert mock_send.call_count == 2 # start + body
|
||||
|
||||
# Check the response
|
||||
start_call = mock_send.call_args_list[0][0][0]
|
||||
assert start_call["status"] == 401
|
||||
|
||||
body_call = mock_send.call_args_list[1][0][0]
|
||||
body_text = body_call["body"].decode()
|
||||
assert "Authentication required" in body_text
|
||||
|
||||
|
||||
def test_has_required_scope_function():
|
||||
"""Test the _has_required_scope function directly"""
|
||||
# Test user with required scope
|
||||
user_with_scope = User(principal="test-user", attributes={"scopes": ["test.read", "other.scope"]})
|
||||
assert _has_required_scope("test.read", user_with_scope)
|
||||
|
||||
# Test user without required scope
|
||||
user_without_scope = User(principal="test-user", attributes={"scopes": ["other.scope"]})
|
||||
assert not _has_required_scope("test.read", user_without_scope)
|
||||
|
||||
# Test user with no scopes attribute
|
||||
user_no_scopes = User(principal="test-user", attributes={})
|
||||
assert not _has_required_scope("test.read", user_no_scopes)
|
||||
|
||||
# Test no user (auth disabled)
|
||||
assert _has_required_scope("test.read", None)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue