Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-21 20:18:52 +00:00
Merge branch 'main' into nvidia-e2e-notebook
Commit 7cdd2a0410: 264 changed files with 229042 additions and 8445 deletions
tests/unit/distribution/test_context.py (new file, 155 lines)
@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar

import pytest

from llama_stack.distribution.utils.context import preserve_contexts_async_generator


@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
    # Create context variable
    context_var = ContextVar("exception_var", default="initial")
    token = context_var.set("start_value")

    # Create an async generator that raises an exception
    async def exception_generator():
        yield context_var.get()
        context_var.set("modified")
        raise ValueError("Test exception")
        yield None  # This will never be reached

    # Wrap the generator
    wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])

    # First iteration should work
    value = await wrapped_gen.__anext__()
    assert value == "start_value"

    # Second iteration should raise the exception
    with pytest.raises(ValueError, match="Test exception"):
        await wrapped_gen.__anext__()

    # Clean up
    context_var.reset(token)


@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
    # Create context variable
    context_var = ContextVar("empty_var", default="initial")
    token = context_var.set("value")

    # Create an empty async generator
    async def empty_generator():
        if False:  # This condition ensures the generator yields nothing
            yield None

    # Wrap the generator
    wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])

    # The generator should raise StopAsyncIteration immediately
    with pytest.raises(StopAsyncIteration):
        await wrapped_gen.__anext__()

    # Context variable should remain unchanged
    assert context_var.get() == "value"

    # Clean up
    context_var.reset(token)


@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
    """
    Test that context variables are preserved across event loop boundaries with nested generators.
    This simulates the real-world scenario where:
    1. A new event loop is created for each streaming request
    2. The async generator runs inside that loop
    3. There are multiple levels of nested generators
    4. Context needs to be preserved across these boundaries
    """
    # Create context variables
    request_id = ContextVar("request_id", default=None)
    user_id = ContextVar("user_id", default=None)

    # Set initial values

    # Results container to verify values across thread boundaries
    results = []

    # Inner-most generator (level 2)
    async def inner_generator():
        # Should have the context from the outer scope
        yield (1, request_id.get(), user_id.get())

        # Modify one context variable
        user_id.set("user-modified")

        # Should reflect the modification
        yield (2, request_id.get(), user_id.get())

    # Middle generator (level 1)
    async def middle_generator():
        inner_gen = inner_generator()

        # Forward the first yield from inner
        item = await inner_gen.__anext__()
        yield item

        # Forward the second yield from inner
        item = await inner_gen.__anext__()
        yield item

        request_id.set("req-modified")

        # Add our own yield with both modified variables
        yield (3, request_id.get(), user_id.get())

    # Function to run in a separate thread with a new event loop
    def run_in_new_loop():
        # Create a new event loop for this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            # Outer generator (runs in the new loop)
            async def outer_generator():
                request_id.set("req-12345")
                user_id.set("user-6789")
                # Wrap the middle generator
                wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])

                # Process all items from the middle generator
                async for item in wrapped_gen:
                    # Store results for verification
                    results.append(item)

            # Run the outer generator in the new loop
            loop.run_until_complete(outer_generator())
        finally:
            loop.close()

    # Run the generator chain in a separate thread with a new event loop
    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(run_in_new_loop)
        future.result()  # Wait for completion

    # Verify the results
    assert len(results) == 3

    # First yield should have original values
    assert results[0] == (1, "req-12345", "user-6789")

    # Second yield should have modified user_id
    assert results[1] == (2, "req-12345", "user-modified")

    # Third yield should have both modified values
    assert results[2] == (3, "req-modified", "user-modified")
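Note (not part of this commit): the tests above pin down the contract of preserve_contexts_async_generator without showing its implementation. A minimal sketch of the behaviour they rely on might look like the following; this is an illustrative reimplementation, assuming the helper captures the listed ContextVars when wrapping and re-applies them around each iteration, and it may differ from the real code in llama_stack.distribution.utils.context.

# Illustrative sketch only -- not the actual llama-stack implementation.
from contextvars import ContextVar
from typing import AsyncGenerator, List


def preserve_contexts_async_generator(
    gen: AsyncGenerator, context_vars: List[ContextVar]
) -> AsyncGenerator:
    # Capture the values visible at wrapping time.
    captured = {var: var.get() for var in context_vars}

    async def wrapper() -> AsyncGenerator:
        while True:
            # Re-apply the captured values in case this iteration runs in a
            # different task or event loop than the one that set them.
            for var, value in captured.items():
                var.set(value)
            try:
                item = await gen.__anext__()
            except StopAsyncIteration:
                break
            # Remember any changes the generator made so they carry over
            # to the next iteration.
            for var in context_vars:
                captured[var] = var.get()
            yield item

    return wrapper()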
tests/unit/distribution/test_distribution.py (new file, 223 lines)
@@ -0,0 +1,223 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict
from unittest.mock import patch

import pytest
import yaml
from pydantic import BaseModel, Field, ValidationError

from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.providers.datatypes import ProviderSpec


class SampleConfig(BaseModel):
    foo: str = Field(
        default="bar",
        description="foo",
    )

    @classmethod
    def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
        return {
            "foo": "baz",
        }


@pytest.fixture
def mock_providers():
    """Mock the available_providers function to return test providers."""
    with patch("llama_stack.providers.registry.inference.available_providers") as mock:
        mock.return_value = [
            ProviderSpec(
                provider_type="test_provider",
                api=Api.inference,
                adapter_type="test_adapter",
                config_class="test_provider.config.TestProviderConfig",
            )
        ]
        yield mock


@pytest.fixture
def base_config(tmp_path):
    """Create a base StackRunConfig with common settings."""
    return StackRunConfig(
        image_name="test_image",
        providers={
            "inference": [
                Provider(
                    provider_id="sample_provider",
                    provider_type="sample",
                    config=SampleConfig.sample_run_config(),
                )
            ]
        },
        external_providers_dir=str(tmp_path),
    )


@pytest.fixture
def provider_spec_yaml():
    """Common provider spec YAML for testing."""
    return """
adapter:
  adapter_type: test_provider
  config_class: test_provider.config.TestProviderConfig
  module: test_provider
api_dependencies:
  - safety
"""


@pytest.fixture
def inline_provider_spec_yaml():
    """Common inline provider spec YAML for testing."""
    return """
module: test_provider
config_class: test_provider.config.TestProviderConfig
pip_packages:
  - test-package
api_dependencies:
  - safety
optional_api_dependencies:
  - vector_io
provider_data_validator: test_provider.validator.TestValidator
container_image: test-image:latest
"""


@pytest.fixture
def api_directories(tmp_path):
    """Create the API directory structure for testing."""
    # Create remote provider directory
    remote_inference_dir = tmp_path / "remote" / "inference"
    remote_inference_dir.mkdir(parents=True, exist_ok=True)

    # Create inline provider directory
    inline_inference_dir = tmp_path / "inline" / "inference"
    inline_inference_dir.mkdir(parents=True, exist_ok=True)

    return remote_inference_dir, inline_inference_dir


class TestProviderRegistry:
    """Test suite for provider registry functionality."""

    def test_builtin_providers(self, mock_providers):
        """Test loading built-in providers."""
        registry = get_provider_registry(None)

        assert Api.inference in registry
        assert "test_provider" in registry[Api.inference]
        assert registry[Api.inference]["test_provider"].provider_type == "test_provider"
        assert registry[Api.inference]["test_provider"].api == Api.inference

    def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml):
        """Test loading external remote providers from YAML files."""
        remote_dir, _ = api_directories
        with open(remote_dir / "test_provider.yaml", "w") as f:
            f.write(provider_spec_yaml)

        registry = get_provider_registry(base_config)
        assert len(registry[Api.inference]) == 2

        assert Api.inference in registry
        assert "remote::test_provider" in registry[Api.inference]
        provider = registry[Api.inference]["remote::test_provider"]
        assert provider.adapter.adapter_type == "test_provider"
        assert provider.adapter.module == "test_provider"
        assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
        assert Api.safety in provider.api_dependencies

    def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
        """Test loading external inline providers from YAML files."""
        _, inline_dir = api_directories
        with open(inline_dir / "test_provider.yaml", "w") as f:
            f.write(inline_provider_spec_yaml)

        registry = get_provider_registry(base_config)
        assert len(registry[Api.inference]) == 2

        assert Api.inference in registry
        assert "inline::test_provider" in registry[Api.inference]
        provider = registry[Api.inference]["inline::test_provider"]
        assert provider.provider_type == "inline::test_provider"
        assert provider.module == "test_provider"
        assert provider.config_class == "test_provider.config.TestProviderConfig"
        assert provider.pip_packages == ["test-package"]
        assert Api.safety in provider.api_dependencies
        assert Api.vector_io in provider.optional_api_dependencies
        assert provider.provider_data_validator == "test_provider.validator.TestValidator"
        assert provider.container_image == "test-image:latest"

    def test_invalid_yaml(self, api_directories, mock_providers, base_config):
        """Test handling of invalid YAML files."""
        remote_dir, inline_dir = api_directories
        with open(remote_dir / "invalid.yaml", "w") as f:
            f.write("invalid: yaml: content: -")
        with open(inline_dir / "invalid.yaml", "w") as f:
            f.write("invalid: yaml: content: -")

        with pytest.raises(yaml.YAMLError):
            get_provider_registry(base_config)

    def test_missing_directory(self, mock_providers):
        """Test handling of missing external providers directory."""
        config = StackRunConfig(
            image_name="test_image",
            providers={
                "inference": [
                    Provider(
                        provider_id="sample_provider",
                        provider_type="sample",
                        config=SampleConfig.sample_run_config(),
                    )
                ]
            },
            external_providers_dir="/nonexistent/dir",
        )
        with pytest.raises(FileNotFoundError):
            get_provider_registry(config)

    def test_empty_api_directory(self, api_directories, mock_providers, base_config):
        """Test handling of empty API directory."""
        registry = get_provider_registry(base_config)
        assert len(registry[Api.inference]) == 1  # Only built-in provider

    def test_malformed_remote_provider_spec(self, api_directories, mock_providers, base_config):
        """Test handling of malformed remote provider spec (missing required fields)."""
        remote_dir, _ = api_directories
        malformed_spec = """
adapter:
  adapter_type: test_provider
  # Missing required fields
api_dependencies:
  - safety
"""
        with open(remote_dir / "malformed.yaml", "w") as f:
            f.write(malformed_spec)

        with pytest.raises(ValidationError):
            get_provider_registry(base_config)

    def test_malformed_inline_provider_spec(self, api_directories, mock_providers, base_config):
        """Test handling of malformed inline provider spec (missing required fields)."""
        _, inline_dir = api_directories
        malformed_spec = """
module: test_provider
# Missing required config_class
pip_packages:
  - test-package
"""
        with open(inline_dir / "malformed.yaml", "w") as f:
            f.write(malformed_spec)

        with pytest.raises(KeyError) as exc_info:
            get_provider_registry(base_config)
        assert "config_class" in str(exc_info.value)
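For orientation (not part of this diff): the api_directories fixture above builds the on-disk layout that get_provider_registry is expected to scan when external_providers_dir is set. Based solely on the fixtures and tests above, the layout looks roughly like this, with one YAML spec per provider under per-API subdirectories:

    <external_providers_dir>/
        remote/
            inference/
                test_provider.yaml   (remote adapter spec, as in provider_spec_yaml)
        inline/
            inference/
                test_provider.yaml   (inline provider spec, as in inline_provider_spec_yaml)

The tests then expect the loaded providers to be registered under the keys remote::test_provider and inline::test_provider respectively, alongside the built-in test_provider.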
tests/unit/models/llama/llama3/test_tool_utils.py (new file, 145 lines)
@@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.llama3.tool_utils import ToolUtils


class TestMaybeExtractCustomToolCall:
    def test_valid_single_tool_call(self):
        input_string = '[get_weather(location="San Francisco", units="celsius")]'
        result = ToolUtils.maybe_extract_custom_tool_call(input_string)

        assert result is not None
        assert len(result) == 2
        assert result[0] == "get_weather"
        assert result[1] == {"location": "San Francisco", "units": "celsius"}

    def test_valid_multiple_tool_calls(self):
        input_string = '[search(query="python programming"), get_time(timezone="UTC")]'
        result = ToolUtils.maybe_extract_custom_tool_call(input_string)

        # Note: maybe_extract_custom_tool_call currently only returns the first tool call
        assert result is not None
        assert len(result) == 2
        assert result[0] == "search"
        assert result[1] == {"query": "python programming"}

    def test_different_value_types(self):
        input_string = '[analyze_data(count=42, enabled=True, ratio=3.14, name="test", options=None)]'
        result = ToolUtils.maybe_extract_custom_tool_call(input_string)

        assert result is not None
        assert len(result) == 2
        assert result[0] == "analyze_data"
        assert result[1] == {"count": 42, "enabled": True, "ratio": 3.14, "name": "test", "options": None}

    def test_nested_structures(self):
        input_string = '[complex_function(filters={"min": 10, "max": 100}, tags=["important", "urgent"])]'
        result = ToolUtils.maybe_extract_custom_tool_call(input_string)

        # This test checks that nested structures are handled
        assert result is not None
        assert len(result) == 2
        assert result[0] == "complex_function"
        assert "filters" in result[1]
        assert sorted(result[1]["filters"].items()) == sorted({"min": 10, "max": 100}.items())

        assert "tags" in result[1]
        assert result[1]["tags"] == ["important", "urgent"]

    def test_hyphenated_function_name(self):
        input_string = '[weather-forecast(city="London")]'
        result = ToolUtils.maybe_extract_custom_tool_call(input_string)

        assert result is not None
        assert len(result) == 2
        assert result[0] == "weather-forecast"  # Function name remains hyphenated
        assert result[1] == {"city": "London"}

    def test_empty_input(self):
        input_string = "[]"
        result = ToolUtils.maybe_extract_custom_tool_call(input_string)

        assert result is None

    def test_invalid_format(self):
        invalid_inputs = [
            'get_weather(location="San Francisco")',  # Missing outer brackets
            '{get_weather(location="San Francisco")}',  # Wrong outer brackets
            '[get_weather(location="San Francisco"]',  # Unmatched brackets
            '[get_weather{location="San Francisco"}]',  # Wrong inner brackets
            "just some text",  # Not a tool call format at all
        ]

        for input_string in invalid_inputs:
            result = ToolUtils.maybe_extract_custom_tool_call(input_string)
            assert result is None

    def test_quotes_handling(self):
        input_string = '[search(query="Text with \\"quotes\\" inside")]'
        result = ToolUtils.maybe_extract_custom_tool_call(input_string)

        # This test checks that escaped quotes are handled correctly
        assert result is not None

    def test_single_quotes_in_arguments(self):
        input_string = "[add-note(name='demonote', content='demonstrating Llama Stack and MCP integration')]"
        result = ToolUtils.maybe_extract_custom_tool_call(input_string)

        assert result is not None
        assert len(result) == 2
        assert result[0] == "add-note"  # Function name remains hyphenated
        assert result[1] == {"name": "demonote", "content": "demonstrating Llama Stack and MCP integration"}

    def test_json_format(self):
        input_string = '{"type": "function", "name": "search_web", "parameters": {"query": "AI research"}}'
        result = ToolUtils.maybe_extract_custom_tool_call(input_string)

        assert result is not None
        assert len(result) == 2
        assert result[0] == "search_web"
        assert result[1] == {"query": "AI research"}

    def test_python_list_format(self):
        input_string = "[calculate(x=10, y=20)]"
        result = ToolUtils.maybe_extract_custom_tool_call(input_string)

        assert result is not None
        assert len(result) == 2
        assert result[0] == "calculate"
        assert result[1] == {"x": 10, "y": 20}

    def test_complex_nested_structures(self):
        input_string = '[advanced_query(config={"filters": {"categories": ["books", "electronics"], "price_range": {"min": 10, "max": 500}}, "sort": {"field": "relevance", "order": "desc"}})]'
        result = ToolUtils.maybe_extract_custom_tool_call(input_string)

        assert result is not None
        assert len(result) == 2
        assert result[0] == "advanced_query"

        # Verify the overall structure
        assert "config" in result[1]
        assert isinstance(result[1]["config"], dict)

        # Verify the first level of nesting
        config = result[1]["config"]
        assert "filters" in config
        assert "sort" in config

        # Verify the second level of nesting (filters)
        filters = config["filters"]
        assert "categories" in filters
        assert "price_range" in filters

        # Verify the list within the dict
        assert filters["categories"] == ["books", "electronics"]

        # Verify the nested dict within another dict
        assert filters["price_range"]["min"] == 10
        assert filters["price_range"]["max"] == 500

        # Verify the sort dictionary
        assert config["sort"]["field"] == "relevance"
        assert config["sort"]["order"] == "desc"
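Aside (not part of this commit): taken together, the cases above pin down the return contract of ToolUtils.maybe_extract_custom_tool_call: a two-element result of (function_name, arguments_dict) when the input parses as a tool call, and None otherwise. A short usage sketch mirroring the calls in the tests:

# Usage sketch; mirrors the calls in the tests above.
from llama_stack.models.llama.llama3.tool_utils import ToolUtils

parsed = ToolUtils.maybe_extract_custom_tool_call('[get_weather(location="San Francisco")]')
if parsed is not None:
    name, args = parsed  # e.g. "get_weather", {"location": "San Francisco"}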
tests/unit/providers/nvidia/test_safety.py (new file, 326 lines)
@@ -0,0 +1,326 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import os
import unittest
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from llama_stack.apis.inference.inference import CompletionMessage, UserMessage
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter


class TestNVIDIASafetyAdapter(unittest.TestCase):
    def setUp(self):
        os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"

        # Initialize the adapter
        self.config = NVIDIASafetyConfig(
            guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
        )
        self.adapter = NVIDIASafetyAdapter(config=self.config)
        self.shield_store = AsyncMock()
        self.adapter.shield_store = self.shield_store

        # Mock the HTTP request methods
        self.guardrails_post_patcher = patch(
            "llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
        )
        self.mock_guardrails_post = self.guardrails_post_patcher.start()
        self.mock_guardrails_post.return_value = {"status": "allowed"}

    def tearDown(self):
        """Clean up after each test."""
        self.guardrails_post_patcher.stop()

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        self.run_async = run_async

    def _assert_request(
        self,
        mock_call: MagicMock,
        expected_url: str,
        expected_headers: dict[str, str] | None = None,
        expected_json: dict[str, Any] | None = None,
    ) -> None:
        """
        Helper method to verify request details in mock API calls.

        Args:
            mock_call: The MagicMock object that was called
            expected_url: The expected URL to which the request was made
            expected_headers: Optional dictionary of expected request headers
            expected_json: Optional dictionary of expected JSON payload
        """
        call_args = mock_call.call_args

        # Check URL
        assert call_args[0][0] == expected_url

        # Check headers if provided
        if expected_headers:
            for key, value in expected_headers.items():
                assert call_args[1]["headers"][key] == value

        # Check JSON if provided
        if expected_json:
            for key, value in expected_json.items():
                if isinstance(value, dict):
                    for nested_key, nested_value in value.items():
                        assert call_args[1]["json"][key][nested_key] == nested_value
                else:
                    assert call_args[1]["json"][key] == value

    def test_register_shield_with_valid_id(self):
        shield = Shield(
            provider_id="nvidia",
            type="shield",
            identifier="test-shield",
            provider_resource_id="test-model",
        )

        # Register the shield
        self.run_async(self.adapter.register_shield(shield))

    def test_register_shield_without_id(self):
        shield = Shield(
            provider_id="nvidia",
            type="shield",
            identifier="test-shield",
            provider_resource_id="",
        )

        # Register the shield should raise a ValueError
        with self.assertRaises(ValueError):
            self.run_async(self.adapter.register_shield(shield))

    def test_run_shield_allowed(self):
        # Set up the shield
        shield_id = "test-shield"
        shield = Shield(
            provider_id="nvidia",
            type="shield",
            identifier=shield_id,
            provider_resource_id="test-model",
        )
        self.shield_store.get_shield.return_value = shield

        # Mock Guardrails API response
        self.mock_guardrails_post.return_value = {"status": "allowed"}

        # Run the shield
        messages = [
            UserMessage(role="user", content="Hello, how are you?"),
            CompletionMessage(
                role="assistant",
                content="I'm doing well, thank you for asking!",
                stop_reason="end_of_message",
                tool_calls=[],
            ),
        ]
        result = self.run_async(self.adapter.run_shield(shield_id, messages))

        # Verify the shield store was called
        self.shield_store.get_shield.assert_called_once_with(shield_id)

        # Verify the Guardrails API was called correctly
        self.mock_guardrails_post.assert_called_once_with(
            path="/v1/guardrail/checks",
            data={
                "model": shield_id,
                "messages": [
                    json.loads(messages[0].model_dump_json()),
                    json.loads(messages[1].model_dump_json()),
                ],
                "temperature": 1.0,
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
                "max_tokens": 160,
                "stream": False,
                "guardrails": {
                    "config_id": "self-check",
                },
            },
        )

        # Verify the result
        assert isinstance(result, RunShieldResponse)
        assert result.violation is None

    def test_run_shield_blocked(self):
        # Set up the shield
        shield_id = "test-shield"
        shield = Shield(
            provider_id="nvidia",
            type="shield",
            identifier=shield_id,
            provider_resource_id="test-model",
        )
        self.shield_store.get_shield.return_value = shield

        # Mock Guardrails API response
        self.mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}

        # Run the shield
        messages = [
            UserMessage(role="user", content="Hello, how are you?"),
            CompletionMessage(
                role="assistant",
                content="I'm doing well, thank you for asking!",
                stop_reason="end_of_message",
                tool_calls=[],
            ),
        ]
        result = self.run_async(self.adapter.run_shield(shield_id, messages))

        # Verify the shield store was called
        self.shield_store.get_shield.assert_called_once_with(shield_id)

        # Verify the Guardrails API was called correctly
        self.mock_guardrails_post.assert_called_once_with(
            path="/v1/guardrail/checks",
            data={
                "model": shield_id,
                "messages": [
                    json.loads(messages[0].model_dump_json()),
                    json.loads(messages[1].model_dump_json()),
                ],
                "temperature": 1.0,
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
                "max_tokens": 160,
                "stream": False,
                "guardrails": {
                    "config_id": "self-check",
                },
            },
        )

        # Verify the result
        assert result.violation is not None
        assert isinstance(result, RunShieldResponse)
        assert result.violation.user_message == "Sorry I cannot do this."
        assert result.violation.violation_level == ViolationLevel.ERROR
        assert result.violation.metadata == {"reason": "harmful_content"}

    def test_run_shield_not_found(self):
        # Set up shield store to return None
        shield_id = "non-existent-shield"
        self.shield_store.get_shield.return_value = None

        messages = [
            UserMessage(role="user", content="Hello, how are you?"),
        ]

        with self.assertRaises(ValueError):
            self.run_async(self.adapter.run_shield(shield_id, messages))

        # Verify the shield store was called
        self.shield_store.get_shield.assert_called_once_with(shield_id)

        # Verify the Guardrails API was not called
        self.mock_guardrails_post.assert_not_called()

    def test_run_shield_http_error(self):
        shield_id = "test-shield"
        shield = Shield(
            provider_id="nvidia",
            type="shield",
            identifier=shield_id,
            provider_resource_id="test-model",
        )
        self.shield_store.get_shield.return_value = shield

        # Mock Guardrails API to raise an exception
        error_msg = "API Error: 500 Internal Server Error"
        self.mock_guardrails_post.side_effect = Exception(error_msg)

        # Running the shield should raise an exception
        messages = [
            UserMessage(role="user", content="Hello, how are you?"),
            CompletionMessage(
                role="assistant",
                content="I'm doing well, thank you for asking!",
                stop_reason="end_of_message",
                tool_calls=[],
            ),
        ]
        with self.assertRaises(Exception) as context:
            self.run_async(self.adapter.run_shield(shield_id, messages))

        # Verify the shield store was called
        self.shield_store.get_shield.assert_called_once_with(shield_id)

        # Verify the Guardrails API was called correctly
        self.mock_guardrails_post.assert_called_once_with(
            path="/v1/guardrail/checks",
            data={
                "model": shield_id,
                "messages": [
                    json.loads(messages[0].model_dump_json()),
                    json.loads(messages[1].model_dump_json()),
                ],
                "temperature": 1.0,
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
                "max_tokens": 160,
                "stream": False,
                "guardrails": {
                    "config_id": "self-check",
                },
            },
        )
        # Verify the exception message
        assert error_msg in str(context.exception)

    def test_init_nemo_guardrails(self):
        from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails

        test_config_id = "test-custom-config-id"
        config = NVIDIASafetyConfig(
            guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
            config_id=test_config_id,
        )
        # Initialize with default parameters
        test_model = "test-model"
        guardrails = NeMoGuardrails(config, test_model)

        # Verify the attributes are set correctly
        assert guardrails.config_id == test_config_id
        assert guardrails.model == test_model
        assert guardrails.threshold == 0.9  # Default value
        assert guardrails.temperature == 1.0  # Default value
        assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]

        # Initialize with custom parameters
        guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)

        # Verify the attributes are set correctly
        assert guardrails.config_id == test_config_id
        assert guardrails.model == test_model
        assert guardrails.threshold == 0.8
        assert guardrails.temperature == 0.7
        assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]

    def test_init_nemo_guardrails_invalid_temperature(self):
        from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails

        config = NVIDIASafetyConfig(
            guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
            config_id="test-custom-config-id",
        )
        with self.assertRaises(ValueError):
            NeMoGuardrails(config, "test-model", temperature=0)
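Aside (not part of this commit): test_run_shield_blocked encodes how a "blocked" Guardrails response is expected to surface to callers. A hedged sketch of that mapping, written only from the assertions above; the adapter's real code may structure this differently, and SafetyViolation is assumed to be the violation datatype exposed by llama_stack.apis.safety.

# Illustrative mapping only, derived from the assertions in the tests above.
from llama_stack.apis.safety import RunShieldResponse, SafetyViolation, ViolationLevel


def guardrails_response_to_shield_response(response: dict) -> RunShieldResponse:
    if response.get("status") == "blocked":
        return RunShieldResponse(
            violation=SafetyViolation(
                user_message="Sorry I cannot do this.",
                violation_level=ViolationLevel.ERROR,
                metadata=response.get("rails_status", {}),
            )
        )
    # "allowed" (or anything non-blocked) maps to a response with no violation.
    return RunShieldResponse(violation=None)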
tests/unit/providers/utils/test_scheduler.py (new file, 120 lines)
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

import pytest

from llama_stack.providers.utils.scheduler import JobStatus, Scheduler


@pytest.mark.asyncio
async def test_scheduler_unknown_backend():
    with pytest.raises(ValueError):
        Scheduler(backend="unknown")


@pytest.mark.asyncio
async def test_scheduler_naive():
    sched = Scheduler()

    # make sure the scheduler starts empty
    with pytest.raises(ValueError):
        sched.get_job("unknown")
    assert sched.get_jobs() == []

    called = False

    # schedule a job that will exercise the handlers
    async def job_handler(on_log, on_status, on_artifact):
        nonlocal called
        called = True
        # exercise the handlers
        on_log("test log1")
        on_log("test log2")
        on_artifact({"type": "type1", "path": "path1"})
        on_artifact({"type": "type2", "path": "path2"})
        on_status(JobStatus.completed)

    job_id = "test_job_id"
    job_type = "test_job_type"
    sched.schedule(job_type, job_id, job_handler)

    # make sure the job was properly registered
    with pytest.raises(ValueError):
        sched.get_job("unknown")
    assert sched.get_job(job_id) is not None
    assert sched.get_jobs() == [sched.get_job(job_id)]

    assert sched.get_jobs("unknown") == []
    assert sched.get_jobs(job_type) == [sched.get_job(job_id)]

    # now shut the scheduler down and make sure the job ran
    await sched.shutdown()

    assert called

    job = sched.get_job(job_id)
    assert job is not None

    assert job.status == JobStatus.completed

    assert job.scheduled_at is not None
    assert job.started_at is not None
    assert job.completed_at is not None
    assert job.scheduled_at < job.started_at < job.completed_at

    assert job.artifacts == [
        {"type": "type1", "path": "path1"},
        {"type": "type2", "path": "path2"},
    ]
    assert [msg[1] for msg in job.logs] == ["test log1", "test log2"]
    assert job.logs[0][0] < job.logs[1][0]


@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises():
    sched = Scheduler()

    async def failing_job_handler(on_log, on_status, on_artifact):
        on_status(JobStatus.running)
        raise ValueError("test error")

    job_id = "test_job_id1"
    job_type = "test_job_type"
    sched.schedule(job_type, job_id, failing_job_handler)

    job = sched.get_job(job_id)
    assert job is not None

    # confirm the exception made the job transition to failed state, even
    # though it was set to `running` before the error
    for _ in range(10):
        if job.status == JobStatus.failed:
            break
        await asyncio.sleep(0.1)
    assert job.status == JobStatus.failed

    # confirm that the raised error got registered in log
    assert job.logs[0][1] == "test error"

    # even after failed job, we can schedule another one
    called = False

    async def successful_job_handler(on_log, on_status, on_artifact):
        nonlocal called
        called = True
        on_status(JobStatus.completed)

    job_id = "test_job_id2"
    sched.schedule(job_type, job_id, successful_job_handler)

    await sched.shutdown()

    assert called
    job = sched.get_job(job_id)
    assert job is not None
    assert job.status == JobStatus.completed
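Usage note (not part of this commit): the tests above exercise the Scheduler API surface a provider would use: a handler coroutine receiving on_log, on_status, and on_artifact callbacks, scheduled synchronously and driven until shutdown(). A minimal, self-contained sketch based only on the calls made in those tests; the job type, job id, and artifact values are hypothetical examples.

# Minimal usage sketch of the Scheduler API as exercised by the tests above.
import asyncio

from llama_stack.providers.utils.scheduler import JobStatus, Scheduler


async def main() -> None:
    sched = Scheduler()  # default backend (the tests call this the "naive" backend)

    async def handler(on_log, on_status, on_artifact):
        on_log("doing work")
        # "checkpoint" and the path are hypothetical example values.
        on_artifact({"type": "checkpoint", "path": "/tmp/example-checkpoint"})
        on_status(JobStatus.completed)

    sched.schedule("example_job_type", "example_job_id", handler)

    # shutdown() waits for scheduled jobs to finish running.
    await sched.shutdown()

    job = sched.get_job("example_job_id")
    print(job.status, job.artifacts)


asyncio.run(main())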