Merge branch 'main' into fix-chroma

commit 062c6a419a
Author: Francisco Arceo (committed by GitHub)
Date: 2025-07-24 21:48:51 -04:00
76 changed files with 2468 additions and 913 deletions


@@ -1,3 +0,0 @@
# Ollama external provider for Llama Stack
Template code to create a new external provider for Llama Stack.


@@ -1,7 +0,0 @@
adapter:
  adapter_type: custom_ollama
  pip_packages: ["ollama", "aiohttp", "tests/external-provider/llama-stack-provider-ollama"]
  config_class: llama_stack_provider_ollama.config.OllamaImplConfig
  module: llama_stack_provider_ollama
api_dependencies: []
optional_api_dependencies: []


@@ -1,43 +0,0 @@
[project]
dependencies = [
  "llama-stack",
  "pydantic",
  "ollama",
  "aiohttp",
  "aiosqlite",
  "autoevals",
  "chardet",
  "chromadb-client",
  "datasets",
  "faiss-cpu",
  "fastapi",
  "fire",
  "httpx",
  "matplotlib",
  "mcp",
  "nltk",
  "numpy",
  "openai",
  "opentelemetry-exporter-otlp-proto-http",
  "opentelemetry-sdk",
  "pandas",
  "pillow",
  "psycopg2-binary",
  "pymongo",
  "pypdf",
  "redis",
  "requests",
  "scikit-learn",
  "scipy",
  "sentencepiece",
  "tqdm",
  "transformers",
  "tree_sitter",
  "uvicorn",
]
name = "llama-stack-provider-ollama"
version = "0.1.0"
description = "External provider for Ollama using the Llama Stack API"
readme = "README.md"
requires-python = ">=3.12"


@@ -1,124 +0,0 @@
version: 2
image_name: ollama
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: ollama
    provider_type: remote::ollama
    config:
      url: ${env.OLLAMA_URL:=http://localhost:11434}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      metadata_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/faiss_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      agents_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db
      responses_store:
        type: sqlite
        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: "${env.OTEL_SERVICE_NAME:=\u200b}"
      sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
      sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/trace_store.db
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      metadata_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config:
      metadata_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db
  - provider_id: localfs
    provider_type: inline::localfs
    config:
      metadata_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:+}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:+}
      max_results: 3
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: wolfram-alpha
    provider_type: remote::wolfram-alpha
    config:
      api_key: ${env.WOLFRAM_ALPHA_API_KEY:+}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
models:
- metadata: {}
  model_id: ${env.INFERENCE_MODEL}
  provider_id: custom_ollama
  model_type: llm
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: custom_ollama
  provider_model_id: all-minilm:l6-v2
  model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::wolfram_alpha
  provider_id: wolfram-alpha
server:
  port: 8321
external_providers_dir: ~/.llama/providers.d


@@ -2,8 +2,9 @@ version: '2'
distribution_spec:
  description: Custom distro for CI tests
  providers:
    inference:
    - remote::custom_ollama
image_type: container
    weather:
    - remote::kaze
image_type: venv
image_name: ci-test
external_providers_dir: ~/.llama/providers.d
external_apis_dir: ~/.llama/apis.d

tests/external/kaze.yaml (vendored, new file, +6)

@@ -0,0 +1,6 @@
adapter:
  adapter_type: kaze
  pip_packages: ["tests/external/llama-stack-provider-kaze"]
  config_class: llama_stack_provider_kaze.config.KazeProviderConfig
  module: llama_stack_provider_kaze
optional_api_dependencies: []


@@ -0,0 +1,15 @@
[project]
name = "llama-stack-api-weather"
version = "0.1.0"
description = "Weather API for Llama Stack"
readme = "README.md"
requires-python = ">=3.10"
dependencies = ["llama-stack", "pydantic"]

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]
include = ["llama_stack_api_weather", "llama_stack_api_weather.*"]


@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Weather API for Llama Stack."""

from .weather import WeatherProvider, available_providers

__all__ = ["WeatherProvider", "available_providers"]


@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Protocol

from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec
from llama_stack.schema_utils import webmethod


def available_providers() -> list[ProviderSpec]:
    return [
        RemoteProviderSpec(
            api=Api.weather,
            provider_type="remote::kaze",
            config_class="llama_stack_provider_kaze.KazeProviderConfig",
            adapter=AdapterSpec(
                adapter_type="kaze",
                module="llama_stack_provider_kaze",
                pip_packages=["llama_stack_provider_kaze"],
                config_class="llama_stack_provider_kaze.KazeProviderConfig",
            ),
        ),
    ]


class WeatherProvider(Protocol):
    """
    A protocol for the Weather API.
    """

    @webmethod(route="/weather/locations", method="GET")
    async def get_available_locations() -> dict[str, list[str]]:
        """
        Get the available locations.
        """
        ...


@@ -0,0 +1,15 @@
[project]
name = "llama-stack-provider-kaze"
version = "0.1.0"
description = "Kaze weather provider for Llama Stack"
readme = "README.md"
requires-python = ">=3.10"
dependencies = ["llama-stack", "pydantic", "aiohttp"]

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]
include = ["llama_stack_provider_kaze", "llama_stack_provider_kaze.*"]


@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Kaze weather provider for Llama Stack."""

from .config import KazeProviderConfig
from .kaze import WeatherKazeAdapter

__all__ = ["KazeProviderConfig", "WeatherKazeAdapter"]


async def get_adapter_impl(config: KazeProviderConfig, _deps):
    from .kaze import WeatherKazeAdapter

    impl = WeatherKazeAdapter(config)
    await impl.initialize()
    return impl


@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel


class KazeProviderConfig(BaseModel):
    """Configuration for the Kaze weather provider."""


@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack_api_weather.weather import WeatherProvider

from .config import KazeProviderConfig


class WeatherKazeAdapter(WeatherProvider):
    """Kaze weather provider implementation."""

    def __init__(
        self,
        config: KazeProviderConfig,
    ) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass

    async def get_available_locations(self) -> dict[str, list[str]]:
        """Get available weather locations."""
        return {"locations": ["Paris", "Tokyo"]}

tests/external/run-byoa.yaml (vendored, new file, +13)

@@ -0,0 +1,13 @@
version: "2"
image_name: "llama-stack-api-weather"
apis:
- weather
providers:
  weather:
  - provider_id: kaze
    provider_type: remote::kaze
    config: {}
external_apis_dir: ~/.llama/apis.d
external_providers_dir: ~/.llama/providers.d
server:
  port: 8321

tests/external/weather.yaml (vendored, new file, +4)

@@ -0,0 +1,4 @@
module: llama_stack_api_weather
name: weather
pip_packages: ["tests/external/llama-stack-api-weather"]
protocol: WeatherProvider


@@ -179,9 +179,7 @@ def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_model
        model=text_model_id,
        prompt=prompt,
        stream=False,
        extra_body={
            "prompt_logprobs": prompt_logprobs,
        },
        prompt_logprobs=prompt_logprobs,
    )
    assert len(response.choices) > 0
    choice = response.choices[0]

@@ -196,9 +194,7 @@ def test_openai_completion_guided_choice(llama_stack_client, client_with_models,
        model=text_model_id,
        prompt=prompt,
        stream=False,
        extra_body={
            "guided_choice": ["joy", "sadness"],
        },
        guided_choice=["joy", "sadness"],
    )
    assert len(response.choices) > 0
    choice = response.choices[0]


@@ -15,6 +15,7 @@ from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.datatypes import RegistryEntrySource
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@@ -45,6 +46,30 @@ class InferenceImpl(Impl):
    async def unregister_model(self, model_id: str):
        return model_id

    async def should_refresh_models(self):
        return False

    async def list_models(self):
        return [
            Model(
                identifier="provider-model-1",
                provider_resource_id="provider-model-1",
                provider_id="test_provider",
                metadata={},
                model_type=ModelType.llm,
            ),
            Model(
                identifier="provider-model-2",
                provider_resource_id="provider-model-2",
                provider_id="test_provider",
                metadata={"embedding_dimension": 512},
                model_type=ModelType.embedding,
            ),
        ]

    async def shutdown(self):
        pass


class SafetyImpl(Impl):
    def __init__(self):
@@ -378,3 +403,170 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
        raise AssertionError("Should have raised ValueError for non-existent model")
    except ValueError as e:
        assert "not found" in str(e)


async def test_models_source_tracking_default(cached_disk_dist_registry):
    """Test that models registered via register_model get default source."""
    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # Register model via register_model (should get default source)
    await table.register_model(model_id="user-model", provider_id="test_provider")

    models = await table.list_models()
    assert len(models.data) == 1
    model = models.data[0]
    assert model.source == RegistryEntrySource.via_register_api
    assert model.identifier == "test_provider/user-model"

    # Cleanup
    await table.shutdown()


async def test_models_source_tracking_provider(cached_disk_dist_registry):
    """Test that models registered via update_registered_models get provider source."""
    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # Simulate provider refresh by calling update_registered_models
    provider_models = [
        Model(
            identifier="provider-model-1",
            provider_resource_id="provider-model-1",
            provider_id="test_provider",
            metadata={},
            model_type=ModelType.llm,
        ),
        Model(
            identifier="provider-model-2",
            provider_resource_id="provider-model-2",
            provider_id="test_provider",
            metadata={"embedding_dimension": 512},
            model_type=ModelType.embedding,
        ),
    ]
    await table.update_registered_models("test_provider", provider_models)

    models = await table.list_models()
    assert len(models.data) == 2

    # All models should have provider source
    for model in models.data:
        assert model.source == RegistryEntrySource.listed_from_provider
        assert model.provider_id == "test_provider"

    # Cleanup
    await table.shutdown()


async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
    """Test that provider refresh preserves user-registered models with default source."""
    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # First register a user model with same provider_resource_id as provider will later provide
    await table.register_model(
        model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
    )

    # Verify user model is registered with default source
    models = await table.list_models()
    assert len(models.data) == 1
    user_model = models.data[0]
    assert user_model.source == RegistryEntrySource.via_register_api
    assert user_model.identifier == "my-custom-alias"
    assert user_model.provider_resource_id == "provider-model-1"

    # Now simulate provider refresh
    provider_models = [
        Model(
            identifier="provider-model-1",
            provider_resource_id="provider-model-1",
            provider_id="test_provider",
            metadata={},
            model_type=ModelType.llm,
        ),
        Model(
            identifier="different-model",
            provider_resource_id="different-model",
            provider_id="test_provider",
            metadata={},
            model_type=ModelType.llm,
        ),
    ]
    await table.update_registered_models("test_provider", provider_models)

    # Verify user model with alias is preserved, but provider added new model
    models = await table.list_models()
    assert len(models.data) == 2

    # Find the user model and provider model
    user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
    provider_model = next((m for m in models.data if m.identifier == "different-model"), None)

    assert user_model is not None
    assert user_model.source == RegistryEntrySource.via_register_api
    assert user_model.provider_resource_id == "provider-model-1"

    assert provider_model is not None
    assert provider_model.source == RegistryEntrySource.listed_from_provider
    assert provider_model.provider_resource_id == "different-model"

    # Cleanup
    await table.shutdown()


async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
    """Test that provider refresh removes old provider models but keeps default ones."""
    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
    await table.initialize()

    # Register a user model
    await table.register_model(model_id="user-model", provider_id="test_provider")

    # Add some provider models
    provider_models_v1 = [
        Model(
            identifier="provider-model-old",
            provider_resource_id="provider-model-old",
            provider_id="test_provider",
            metadata={},
            model_type=ModelType.llm,
        ),
    ]
    await table.update_registered_models("test_provider", provider_models_v1)

    # Verify we have both user and provider models
    models = await table.list_models()
    assert len(models.data) == 2

    # Now update with new provider models (should remove old provider models)
    provider_models_v2 = [
        Model(
            identifier="provider-model-new",
            provider_resource_id="provider-model-new",
            provider_id="test_provider",
            metadata={},
            model_type=ModelType.llm,
        ),
    ]
    await table.update_registered_models("test_provider", provider_models_v2)

    # Should have user model + new provider model, old provider model gone
    models = await table.list_models()
    assert len(models.data) == 2

    identifiers = {m.identifier for m in models.data}
    assert "test_provider/user-model" in identifiers  # User model preserved
    assert "provider-model-new" in identifiers  # New provider model (uses provider's identifier)
    assert "provider-model-old" not in identifiers  # Old provider model removed

    # Verify sources are correct
    user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
    provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None)

    assert user_model.source == RegistryEntrySource.via_register_api
    assert provider_model.source == RegistryEntrySource.listed_from_provider

    # Cleanup
    await table.shutdown()
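
Taken together, the four tests above pin down the refresh semantics: entries created through register_model (source via_register_api) survive provider refreshes, entries previously listed by the provider (source listed_from_provider) are replaced wholesale, and provider-listed models whose provider_resource_id is already covered by a user registration are not duplicated. A condensed restatement of that rule (a hypothetical helper written for illustration, not code from ModelsRoutingTable):

def merge_on_provider_refresh(existing: list[Model], listed: list[Model]) -> list[Model]:
    # User-registered entries always survive a provider refresh.
    kept = [m for m in existing if m.source == RegistryEntrySource.via_register_api]
    covered = {m.provider_resource_id for m in kept}
    # Provider-listed entries replace the previous provider snapshot,
    # except where a user registration already points at the same resource.
    return kept + [m for m in listed if m.provider_resource_id not in covered]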


@@ -5,321 +5,353 @@
# the root directory of this source tree.

import os
import unittest
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from llama_stack.apis.inference import CompletionMessage, UserMessage
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter


class TestNVIDIASafetyAdapter(unittest.TestCase):
    def setUp(self):
        os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"


class TestNVIDIASafetyAdapter(NVIDIASafetyAdapter):
    """Test implementation that provides the required shield_store."""

        # Initialize the adapter
        self.config = NVIDIASafetyConfig(
            guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
        )
        self.adapter = NVIDIASafetyAdapter(config=self.config)
        self.shield_store = AsyncMock()
        self.adapter.shield_store = self.shield_store

    def __init__(self, config: NVIDIASafetyConfig, shield_store):
        super().__init__(config)
        self.shield_store = shield_store

        # Mock the HTTP request methods
        self.guardrails_post_patcher = patch(
            "llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
        )
        self.mock_guardrails_post = self.guardrails_post_patcher.start()
        self.mock_guardrails_post.return_value = {"status": "allowed"}

    def tearDown(self):
        """Clean up after each test."""
        self.guardrails_post_patcher.stop()


@pytest.fixture
def nvidia_adapter():
    """Set up the NVIDIASafetyAdapter for testing."""
    os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        self.run_async = run_async

    # Initialize the adapter
    config = NVIDIASafetyConfig(
        guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
    )

    def _assert_request(
        self,
        mock_call: MagicMock,
        expected_url: str,
        expected_headers: dict[str, str] | None = None,
        expected_json: dict[str, Any] | None = None,
    ) -> None:
        """
        Helper method to verify request details in mock API calls.

    # Create a mock shield store that implements the ShieldStore protocol
    shield_store = AsyncMock()
    shield_store.get_shield = AsyncMock()

        Args:
            mock_call: The MagicMock object that was called
            expected_url: The expected URL to which the request was made
            expected_headers: Optional dictionary of expected request headers
            expected_json: Optional dictionary of expected JSON payload
        """
        call_args = mock_call.call_args

    adapter = TestNVIDIASafetyAdapter(config=config, shield_store=shield_store)

        # Check URL
        assert call_args[0][0] == expected_url

    return adapter

        # Check headers if provided
        if expected_headers:
            for key, value in expected_headers.items():
                assert call_args[1]["headers"][key] == value

        # Check JSON if provided
        if expected_json:
            for key, value in expected_json.items():
                if isinstance(value, dict):
                    for nested_key, nested_value in value.items():
                        assert call_args[1]["json"][key][nested_key] == nested_value
                else:
                    assert call_args[1]["json"][key] == value


@pytest.fixture
def mock_guardrails_post():
    """Mock the HTTP request methods."""
    with patch("llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post") as mock_post:
        mock_post.return_value = {"status": "allowed"}
        yield mock_post

    def test_register_shield_with_valid_id(self):
        shield = Shield(
            provider_id="nvidia",
            type="shield",
            identifier="test-shield",
            provider_resource_id="test-model",
        )

        # Register the shield
        self.run_async(self.adapter.register_shield(shield))


def _assert_request(
    mock_call: MagicMock,
    expected_url: str,
    expected_headers: dict[str, str] | None = None,
    expected_json: dict[str, Any] | None = None,
) -> None:
    """
    Helper method to verify request details in mock API calls.

    def test_register_shield_without_id(self):
        shield = Shield(
            provider_id="nvidia",
            type="shield",
            identifier="test-shield",
            provider_resource_id="",
        )

    Args:
        mock_call: The MagicMock object that was called
        expected_url: The expected URL to which the request was made
        expected_headers: Optional dictionary of expected request headers
        expected_json: Optional dictionary of expected JSON payload
    """
    call_args = mock_call.call_args

        # Register the shield should raise a ValueError
        with self.assertRaises(ValueError):
            self.run_async(self.adapter.register_shield(shield))

    # Check URL
    assert call_args[0][0] == expected_url

    def test_run_shield_allowed(self):
        # Set up the shield
        shield_id = "test-shield"
        shield = Shield(
            provider_id="nvidia",
            type="shield",
            identifier=shield_id,
            provider_resource_id="test-model",
        )
        self.shield_store.get_shield.return_value = shield

    # Check headers if provided
    if expected_headers:
        for key, value in expected_headers.items():
            assert call_args[1]["headers"][key] == value

        # Mock Guardrails API response
        self.mock_guardrails_post.return_value = {"status": "allowed"}

    # Check JSON if provided
    if expected_json:
        for key, value in expected_json.items():
            if isinstance(value, dict):
                for nested_key, nested_value in value.items():
                    assert call_args[1]["json"][key][nested_key] == nested_value
            else:
                assert call_args[1]["json"][key] == value

        # Run the shield
        messages = [
            UserMessage(role="user", content="Hello, how are you?"),
            CompletionMessage(
                role="assistant",
                content="I'm doing well, thank you for asking!",
                stop_reason="end_of_message",
                tool_calls=[],
            ),
        ]
        result = self.run_async(self.adapter.run_shield(shield_id, messages))

        # Verify the shield store was called
        self.shield_store.get_shield.assert_called_once_with(shield_id)


async def test_register_shield_with_valid_id(nvidia_adapter):
    adapter = nvidia_adapter

        # Verify the Guardrails API was called correctly
        self.mock_guardrails_post.assert_called_once_with(
            path="/v1/guardrail/checks",
            data={
                "model": shield_id,
                "messages": [
                    {"role": "user", "content": "Hello, how are you?"},
                    {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
                ],
                "temperature": 1.0,
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
                "max_tokens": 160,
                "stream": False,
                "guardrails": {
                    "config_id": "self-check",
                },

    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier="test-shield",
        provider_resource_id="test-model",
    )

    # Register the shield
    await adapter.register_shield(shield)


async def test_register_shield_without_id(nvidia_adapter):
    adapter = nvidia_adapter

    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier="test-shield",
        provider_resource_id="",
    )

    # Register the shield should raise a ValueError
    with pytest.raises(ValueError):
        await adapter.register_shield(shield)


async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
    adapter = nvidia_adapter

    # Set up the shield
    shield_id = "test-shield"
    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier=shield_id,
        provider_resource_id="test-model",
    )
    adapter.shield_store.get_shield.return_value = shield

    # Mock Guardrails API response
    mock_guardrails_post.return_value = {"status": "allowed"}

    # Run the shield
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
            content="I'm doing well, thank you for asking!",
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
    result = await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    adapter.shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was called correctly
    mock_guardrails_post.assert_called_once_with(
        path="/v1/guardrail/checks",
        data={
            "model": shield_id,
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
            ],
            "temperature": 1.0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": 160,
            "stream": False,
            "guardrails": {
                "config_id": "self-check",
            },
        )
        },
    )

        # Verify the result
        assert isinstance(result, RunShieldResponse)
        assert result.violation is None

    # Verify the result
    assert isinstance(result, RunShieldResponse)
    assert result.violation is None

    def test_run_shield_blocked(self):
        # Set up the shield
        shield_id = "test-shield"
        shield = Shield(
            provider_id="nvidia",
            type="shield",
            identifier=shield_id,
            provider_resource_id="test-model",
        )
        self.shield_store.get_shield.return_value = shield

        # Mock Guardrails API response
        self.mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}


async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
    adapter = nvidia_adapter

        # Run the shield
        messages = [
            UserMessage(role="user", content="Hello, how are you?"),
            CompletionMessage(
                role="assistant",
                content="I'm doing well, thank you for asking!",
                stop_reason="end_of_message",
                tool_calls=[],
            ),
        ]
        result = self.run_async(self.adapter.run_shield(shield_id, messages))

    # Set up the shield
    shield_id = "test-shield"
    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier=shield_id,
        provider_resource_id="test-model",
    )
    adapter.shield_store.get_shield.return_value = shield

        # Verify the shield store was called
        self.shield_store.get_shield.assert_called_once_with(shield_id)

    # Mock Guardrails API response
    mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}

        # Verify the Guardrails API was called correctly
        self.mock_guardrails_post.assert_called_once_with(
            path="/v1/guardrail/checks",
            data={
                "model": shield_id,
                "messages": [
                    {"role": "user", "content": "Hello, how are you?"},
                    {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
                ],
                "temperature": 1.0,
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
                "max_tokens": 160,
                "stream": False,
                "guardrails": {
                    "config_id": "self-check",
                },

    # Run the shield
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
            content="I'm doing well, thank you for asking!",
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
    result = await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    adapter.shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was called correctly
    mock_guardrails_post.assert_called_once_with(
        path="/v1/guardrail/checks",
        data={
            "model": shield_id,
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
            ],
            "temperature": 1.0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": 160,
            "stream": False,
            "guardrails": {
                "config_id": "self-check",
            },
        )
        },
    )

        # Verify the result
        assert result.violation is not None
        assert isinstance(result, RunShieldResponse)
        assert result.violation.user_message == "Sorry I cannot do this."
        assert result.violation.violation_level == ViolationLevel.ERROR
        assert result.violation.metadata == {"reason": "harmful_content"}

    # Verify the result
    assert result.violation is not None
    assert isinstance(result, RunShieldResponse)
    assert result.violation.user_message == "Sorry I cannot do this."
    assert result.violation.violation_level == ViolationLevel.ERROR
    assert result.violation.metadata == {"reason": "harmful_content"}

    def test_run_shield_not_found(self):
        # Set up shield store to return None
        shield_id = "non-existent-shield"
        self.shield_store.get_shield.return_value = None

        messages = [
            UserMessage(role="user", content="Hello, how are you?"),
        ]


async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
    adapter = nvidia_adapter

        with self.assertRaises(ValueError):
            self.run_async(self.adapter.run_shield(shield_id, messages))

    # Set up shield store to return None
    shield_id = "non-existent-shield"
    adapter.shield_store.get_shield.return_value = None

        # Verify the shield store was called
        self.shield_store.get_shield.assert_called_once_with(shield_id)

    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
    ]

        # Verify the Guardrails API was not called
        self.mock_guardrails_post.assert_not_called()

    with pytest.raises(ValueError):
        await adapter.run_shield(shield_id, messages)

    def test_run_shield_http_error(self):
        shield_id = "test-shield"
        shield = Shield(
            provider_id="nvidia",
            type="shield",
            identifier=shield_id,
            provider_resource_id="test-model",
        )
        self.shield_store.get_shield.return_value = shield

    # Verify the shield store was called
    adapter.shield_store.get_shield.assert_called_once_with(shield_id)

        # Mock Guardrails API to raise an exception
        error_msg = "API Error: 500 Internal Server Error"
        self.mock_guardrails_post.side_effect = Exception(error_msg)

    # Verify the Guardrails API was not called
    mock_guardrails_post.assert_not_called()

        # Running the shield should raise an exception
        messages = [
            UserMessage(role="user", content="Hello, how are you?"),
            CompletionMessage(
                role="assistant",
                content="I'm doing well, thank you for asking!",
                stop_reason="end_of_message",
                tool_calls=[],
            ),
        ]
        with self.assertRaises(Exception) as context:
            self.run_async(self.adapter.run_shield(shield_id, messages))

        # Verify the shield store was called
        self.shield_store.get_shield.assert_called_once_with(shield_id)


async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
    adapter = nvidia_adapter

        # Verify the Guardrails API was called correctly
        self.mock_guardrails_post.assert_called_once_with(
            path="/v1/guardrail/checks",
            data={
                "model": shield_id,
                "messages": [
                    {"role": "user", "content": "Hello, how are you?"},
                    {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
                ],
                "temperature": 1.0,
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
                "max_tokens": 160,
                "stream": False,
                "guardrails": {
                    "config_id": "self-check",
                },

    shield_id = "test-shield"
    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier=shield_id,
        provider_resource_id="test-model",
    )
    adapter.shield_store.get_shield.return_value = shield

    # Mock Guardrails API to raise an exception
    error_msg = "API Error: 500 Internal Server Error"
    mock_guardrails_post.side_effect = Exception(error_msg)

    # Running the shield should raise an exception
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
            content="I'm doing well, thank you for asking!",
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
    with pytest.raises(Exception) as exc_info:
        await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    adapter.shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was called correctly
    mock_guardrails_post.assert_called_once_with(
        path="/v1/guardrail/checks",
        data={
            "model": shield_id,
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
            ],
            "temperature": 1.0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": 160,
            "stream": False,
            "guardrails": {
                "config_id": "self-check",
            },
        )

        # Verify the exception message
        assert error_msg in str(context.exception)

        },
    )

    # Verify the exception message
    assert error_msg in str(exc_info.value)

    def test_init_nemo_guardrails(self):
        from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails

        test_config_id = "test-custom-config-id"
        config = NVIDIASafetyConfig(
            guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
            config_id=test_config_id,
        )

        # Initialize with default parameters
        test_model = "test-model"
        guardrails = NeMoGuardrails(config, test_model)


def test_init_nemo_guardrails():
    from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails

        # Verify the attributes are set correctly
        assert guardrails.config_id == test_config_id
        assert guardrails.model == test_model
        assert guardrails.threshold == 0.9  # Default value
        assert guardrails.temperature == 1.0  # Default value
        assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]

    os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"

        # Initialize with custom parameters
        guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)

    test_config_id = "test-custom-config-id"
    config = NVIDIASafetyConfig(
        guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
        config_id=test_config_id,
    )

    # Initialize with default parameters
    test_model = "test-model"
    guardrails = NeMoGuardrails(config, test_model)

        # Verify the attributes are set correctly
        assert guardrails.config_id == test_config_id
        assert guardrails.model == test_model
        assert guardrails.threshold == 0.8
        assert guardrails.temperature == 0.7
        assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]

    # Verify the attributes are set correctly
    assert guardrails.config_id == test_config_id
    assert guardrails.model == test_model
    assert guardrails.threshold == 0.9  # Default value
    assert guardrails.temperature == 1.0  # Default value
    assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]

    def test_init_nemo_guardrails_invalid_temperature(self):
        from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails

    # Initialize with custom parameters
    guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)

        config = NVIDIASafetyConfig(
            guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
            config_id="test-custom-config-id",
        )

        with self.assertRaises(ValueError):
            NeMoGuardrails(config, "test-model", temperature=0)

    # Verify the attributes are set correctly
    assert guardrails.config_id == test_config_id
    assert guardrails.model == test_model
    assert guardrails.threshold == 0.8
    assert guardrails.temperature == 0.7
    assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]


def test_init_nemo_guardrails_invalid_temperature():
    from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails

    os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"

    config = NVIDIASafetyConfig(
        guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
        config_id="test-custom-config-id",
    )

    with pytest.raises(ValueError):
        NeMoGuardrails(config, "test-model", temperature=0)


@@ -19,7 +19,8 @@ from llama_stack.distribution.datatypes import (
    OAuth2JWKSConfig,
    OAuth2TokenAuthConfig,
)
from llama_stack.distribution.server.auth import AuthenticationMiddleware
from llama_stack.distribution.request_headers import User
from llama_stack.distribution.server.auth import AuthenticationMiddleware, _has_required_scope
from llama_stack.distribution.server.auth_providers import (
    get_attributes_from_claims,
)
@@ -73,7 +74,7 @@ def http_app(mock_auth_endpoint):
        ),
        access_policy=[],
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():
@@ -111,7 +112,50 @@ def mock_http_middleware(mock_auth_endpoint):
        ),
        access_policy=[],
    )
    return AuthenticationMiddleware(mock_app, auth_config), mock_app
    return AuthenticationMiddleware(mock_app, auth_config, {}), mock_app


@pytest.fixture
def mock_impls():
    """Mock implementations for scope testing"""
    return {}


@pytest.fixture
def scope_middleware_with_mocks(mock_auth_endpoint):
    """Create AuthenticationMiddleware with mocked route implementations"""
    mock_app = AsyncMock()
    auth_config = AuthenticationConfig(
        provider_config=CustomAuthConfig(
            type=AuthProviderType.CUSTOM,
            endpoint=mock_auth_endpoint,
        ),
        access_policy=[],
    )
    middleware = AuthenticationMiddleware(mock_app, auth_config, {})

    # Mock the route_impls to simulate finding routes with required scopes
    from llama_stack.schema_utils import WebMethod

    scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read")
    public_webmethod = WebMethod(route="/test/public", method="GET")

    # Mock the route finding logic
    def mock_find_matching_route(method, path, route_impls):
        if method == "POST" and path == "/test/scoped":
            return None, {}, "/test/scoped", scoped_webmethod
        elif method == "GET" and path == "/test/public":
            return None, {}, "/test/public", public_webmethod
        else:
            raise ValueError("No matching route")

    import llama_stack.distribution.server.auth

    llama_stack.distribution.server.auth.find_matching_route = mock_find_matching_route
    llama_stack.distribution.server.auth.initialize_route_impls = lambda impls: {}

    return middleware, mock_app
async def mock_post_success(*args, **kwargs):
@@ -138,6 +182,36 @@ async def mock_post_exception(*args, **kwargs):
    raise Exception("Connection error")


async def mock_post_success_with_scope(*args, **kwargs):
    """Mock auth response for user with test.read scope"""
    return MockResponse(
        200,
        {
            "message": "Authentication successful",
            "principal": "test-user",
            "attributes": {
                "scopes": ["test.read", "other.scope"],
                "roles": ["user"],
            },
        },
    )


async def mock_post_success_no_scope(*args, **kwargs):
    """Mock auth response for user without required scope"""
    return MockResponse(
        200,
        {
            "message": "Authentication successful",
            "principal": "test-user",
            "attributes": {
                "scopes": ["other.scope"],
                "roles": ["user"],
            },
        },
    )
# HTTP Endpoint Tests
def test_missing_auth_header(http_client):
    response = http_client.get("/test")
@@ -252,7 +326,7 @@ def oauth2_app():
        ),
        access_policy=[],
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():
@@ -351,7 +425,7 @@ def oauth2_app_with_jwks_token():
        ),
        access_policy=[],
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():
@@ -442,7 +516,7 @@ def introspection_app(mock_introspection_endpoint):
        ),
        access_policy=[],
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():
@@ -472,7 +546,7 @@ def introspection_app_with_custom_mapping(mock_introspection_endpoint):
        ),
        access_policy=[],
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():
@@ -581,3 +655,122 @@ def test_valid_introspection_with_custom_mapping_authentication(
    )
    assert response.status_code == 200
    assert response.json() == {"message": "Authentication successful"}


# Scope-based authorization tests
@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope)
async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key):
    """Test that user with required scope can access protected endpoint"""
    middleware, mock_app = scope_middleware_with_mocks
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    scope = {
        "type": "http",
        "path": "/test/scoped",
        "method": "POST",
        "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
    }

    await middleware(scope, mock_receive, mock_send)

    # Should call the downstream app (no 403 error sent)
    mock_app.assert_called_once_with(scope, mock_receive, mock_send)
    mock_send.assert_not_called()


@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key):
    """Test that user without required scope gets 403 access denied"""
    middleware, mock_app = scope_middleware_with_mocks
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    scope = {
        "type": "http",
        "path": "/test/scoped",
        "method": "POST",
        "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
    }

    await middleware(scope, mock_receive, mock_send)

    # Should send 403 error, not call downstream app
    mock_app.assert_not_called()
    assert mock_send.call_count == 2  # start + body

    # Check the response
    start_call = mock_send.call_args_list[0][0][0]
    assert start_call["status"] == 403

    body_call = mock_send.call_args_list[1][0][0]
    body_text = body_call["body"].decode()
    assert "Access denied" in body_text
    assert "test.read" in body_text


@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key):
    """Test that public endpoints work without specific scopes"""
    middleware, mock_app = scope_middleware_with_mocks
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    scope = {
        "type": "http",
        "path": "/test/public",
        "method": "GET",
        "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
    }

    await middleware(scope, mock_receive, mock_send)

    # Should call the downstream app (no error)
    mock_app.assert_called_once_with(scope, mock_receive, mock_send)
    mock_send.assert_not_called()


async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks):
    """Test that when auth is disabled (no user), scope checks are bypassed"""
    middleware, mock_app = scope_middleware_with_mocks
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    scope = {
        "type": "http",
        "path": "/test/scoped",
        "method": "POST",
        "headers": [],  # No authorization header
    }

    await middleware(scope, mock_receive, mock_send)

    # Should send 401 auth error, not call downstream app
    mock_app.assert_not_called()
    assert mock_send.call_count == 2  # start + body

    # Check the response
    start_call = mock_send.call_args_list[0][0][0]
    assert start_call["status"] == 401

    body_call = mock_send.call_args_list[1][0][0]
    body_text = body_call["body"].decode()
    assert "Authentication required" in body_text


def test_has_required_scope_function():
    """Test the _has_required_scope function directly"""
    # Test user with required scope
    user_with_scope = User(principal="test-user", attributes={"scopes": ["test.read", "other.scope"]})
    assert _has_required_scope("test.read", user_with_scope)

    # Test user without required scope
    user_without_scope = User(principal="test-user", attributes={"scopes": ["other.scope"]})
    assert not _has_required_scope("test.read", user_without_scope)

    # Test user with no scopes attribute
    user_no_scopes = User(principal="test-user", attributes={})
    assert not _has_required_scope("test.read", user_no_scopes)

    # Test no user (auth disabled)
    assert _has_required_scope("test.read", None)
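
The last test fully specifies _has_required_scope: a user passes when the required scope appears in their scopes attribute, fails otherwise, and a missing user (auth disabled) always passes. In sketch form (an illustration of the behavior the assertions pin down, not the actual implementation in llama_stack.distribution.server.auth):

def _has_required_scope_sketch(required_scope: str, user: User | None) -> bool:
    if user is None:
        # No authenticated user means auth is disabled, so scope checks are bypassed.
        return True
    scopes = (user.attributes or {}).get("scopes", [])
    return required_scope in scopes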