Merge branch 'main' into add-llama-guard-4-model

This commit is contained in:
raghotham 2025-07-03 10:52:01 -07:00 committed by GitHub
commit bae3c766bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
72 changed files with 990 additions and 337 deletions

View file

@ -9,7 +9,9 @@ pytest --help
```
Here are the most important options:
- `--stack-config`: specify the stack config to use. You have three ways to point to a stack:
- `--stack-config`: specify the stack config to use. You have four ways to point to a stack:
- **`server:<config>`** - automatically start a server with the given config (e.g., `server:fireworks`). This provides one-step testing by auto-starting the server if the port is available, or reusing an existing server if already running.
- **`server:<config>:<port>`** - same as above but with a custom port (e.g., `server:together:8322`)
- a URL which points to a Llama Stack distribution server
- a template (e.g., `fireworks`, `together`) or a path to a `run.yaml` file
- a comma-separated list of api=provider pairs, e.g. `inference=fireworks,safety=llama-guard,agents=meta-reference`. This is most useful for testing a single API surface.
@ -26,12 +28,39 @@ Model parameters can be influenced by the following options:
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
if no model is specified.
Experimental, under development, options:
- `--record-responses`: record new API responses instead of using cached ones
## Examples
### Testing against a Server
Run all text inference tests by auto-starting a server with the `fireworks` config:
```bash
pytest -s -v tests/integration/inference/test_text_inference.py \
--stack-config=server:fireworks \
--text-model=meta-llama/Llama-3.1-8B-Instruct
```
Run tests with auto-server startup on a custom port:
```bash
pytest -s -v tests/integration/inference/ \
--stack-config=server:together:8322 \
--text-model=meta-llama/Llama-3.1-8B-Instruct
```
Run multiple test suites with auto-server (eliminates manual server management):
```bash
# Auto-start server and run all integration tests
export FIREWORKS_API_KEY=<your_key>
pytest -s -v tests/integration/inference/ tests/integration/safety/ tests/integration/agents/ \
--stack-config=server:fireworks \
--text-model=meta-llama/Llama-3.1-8B-Instruct
```
### Testing with Library Client
Run all text inference tests with the `together` distribution:
```bash

View file

@ -6,9 +6,13 @@
import inspect
import os
import socket
import subprocess
import tempfile
import time
import pytest
import requests
import yaml
from llama_stack_client import LlamaStackClient
from openai import OpenAI
@ -17,6 +21,60 @@ from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.stack import run_config_from_adhoc_config_spec
from llama_stack.env import get_env_or_fail
DEFAULT_PORT = 8321
def is_port_available(port: int, host: str = "localhost") -> bool:
"""Check if a port is available for binding."""
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind((host, port))
return True
except OSError:
return False
def start_llama_stack_server(config_name: str) -> subprocess.Popen:
"""Start a llama stack server with the given config."""
cmd = ["llama", "stack", "run", config_name]
devnull = open(os.devnull, "w")
process = subprocess.Popen(
cmd,
stdout=devnull, # redirect stdout to devnull to prevent deadlock
stderr=devnull, # redirect stderr to devnull to prevent deadlock
text=True,
env={**os.environ, "LLAMA_STACK_LOG_FILE": "server.log"},
)
return process
def wait_for_server_ready(base_url: str, timeout: int = 30, process: subprocess.Popen | None = None) -> bool:
"""Wait for the server to be ready by polling the health endpoint."""
health_url = f"{base_url}/v1/health"
start_time = time.time()
while time.time() - start_time < timeout:
if process and process.poll() is not None:
print(f"Server process terminated with return code: {process.returncode}")
return False
try:
response = requests.get(health_url, timeout=5)
if response.status_code == 200:
return True
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
pass
# Print progress every 5 seconds
elapsed = time.time() - start_time
if int(elapsed) % 5 == 0 and elapsed > 0:
print(f"Waiting for server at {base_url}... ({elapsed:.1f}s elapsed)")
time.sleep(0.5)
print(f"Server failed to respond within {timeout} seconds")
return False
@pytest.fixture(scope="session")
def provider_data():
@ -122,6 +180,41 @@ def llama_stack_client(request, provider_data):
if not config:
raise ValueError("You must specify either --stack-config or LLAMA_STACK_CONFIG")
# Handle server:<config_name> format or server:<config_name>:<port>
if config.startswith("server:"):
parts = config.split(":")
config_name = parts[1]
port = int(parts[2]) if len(parts) > 2 else int(os.environ.get("LLAMA_STACK_PORT", DEFAULT_PORT))
base_url = f"http://localhost:{port}"
# Check if port is available
if is_port_available(port):
print(f"Starting llama stack server with config '{config_name}' on port {port}...")
# Start server
server_process = start_llama_stack_server(config_name)
# Wait for server to be ready
if not wait_for_server_ready(base_url, timeout=30, process=server_process):
print("Server failed to start within timeout")
server_process.terminate()
raise RuntimeError(
f"Server failed to start within timeout. Check that config '{config_name}' exists and is valid. "
f"See server.log for details."
)
print(f"Server is ready at {base_url}")
# Store process for potential cleanup (pytest will handle termination at session end)
request.session._llama_stack_server_process = server_process
else:
print(f"Port {port} is already in use, assuming server is already running...")
return LlamaStackClient(
base_url=base_url,
provider_data=provider_data,
)
# check if this looks like a URL
if config.startswith("http") or "//" in config:
return LlamaStackClient(
@ -151,3 +244,31 @@ def llama_stack_client(request, provider_data):
def openai_client(client_with_models):
base_url = f"{client_with_models.base_url}/v1/openai/v1"
return OpenAI(base_url=base_url, api_key="fake")
@pytest.fixture(scope="session", autouse=True)
def cleanup_server_process(request):
"""Cleanup server process at the end of the test session."""
yield # Run tests
if hasattr(request.session, "_llama_stack_server_process"):
server_process = request.session._llama_stack_server_process
if server_process:
if server_process.poll() is None:
print("Terminating llama stack server process...")
else:
print(f"Server process already terminated with return code: {server_process.returncode}")
return
try:
server_process.terminate()
server_process.wait(timeout=10)
print("Server process terminated gracefully")
except subprocess.TimeoutExpired:
print("Server process did not terminate gracefully, killing it")
server_process.kill()
server_process.wait()
print("Server process killed")
except Exception as e:
print(f"Error during server cleanup: {e}")
else:
print("Server process not found - won't be able to cleanup")

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture(scope="function")
@pytest_asyncio.fixture(scope="function")
async def sqlite_kvstore(tmp_path):
db_path = tmp_path / "test_kv.db"
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
@ -20,14 +20,14 @@ async def sqlite_kvstore(tmp_path):
yield kvstore
@pytest.fixture(scope="function")
@pytest_asyncio.fixture(scope="function")
async def disk_dist_registry(sqlite_kvstore):
registry = DiskDistributionRegistry(sqlite_kvstore)
await registry.initialize()
yield registry
@pytest.fixture(scope="function")
@pytest_asyncio.fixture(scope="function")
async def cached_disk_dist_registry(sqlite_kvstore):
registry = CachedDiskDistributionRegistry(sqlite_kvstore)
await registry.initialize()

View file

@ -9,6 +9,7 @@ from datetime import datetime
from unittest.mock import patch
import pytest
import pytest_asyncio
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
@ -16,7 +17,7 @@ from llama_stack.distribution.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest.fixture
@pytest_asyncio.fixture
async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence

View file

@ -148,7 +148,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
@pytest.fixture(scope="session")
@pytest_asyncio.fixture(scope="session")
async def sqlite_vec_adapter(sqlite_connection):
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)

View file

@ -7,6 +7,7 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
import pytest_asyncio
import yaml
from pydantic import TypeAdapter, ValidationError
@ -26,7 +27,7 @@ def _return_model(model):
return model
@pytest.fixture
@pytest_asyncio.fixture
async def test_setup(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
@ -245,7 +246,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
assert model.identifier == "auto-access-model"
@pytest.fixture
@pytest_asyncio.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()

View file

@ -0,0 +1,187 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import Mock
from fastapi import HTTPException
from openai import BadRequestError
from pydantic import ValidationError
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.server.server import translate_exception
class TestTranslateException:
"""Test cases for the translate_exception function."""
def test_translate_access_denied_error(self):
"""Test that AccessDeniedError is translated to 403 HTTP status."""
exc = AccessDeniedError()
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert result.detail == "Permission denied: Insufficient permissions"
def test_translate_access_denied_error_with_context(self):
"""Test that AccessDeniedError with context includes detailed information."""
from llama_stack.distribution.datatypes import User
# Create mock user and resource
user = User("test-user", {"roles": ["user"], "teams": ["dev"]})
# Create a simple mock object that implements the ProtectedResource protocol
class MockResource:
def __init__(self, type: str, identifier: str, owner=None):
self.type = type
self.identifier = identifier
self.owner = owner
resource = MockResource("vector_db", "test-db")
exc = AccessDeniedError("create", resource, user)
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert "test-user" in result.detail
assert "vector_db::test-db" in result.detail
assert "create" in result.detail
assert "roles=['user']" in result.detail
assert "teams=['dev']" in result.detail
def test_translate_permission_error(self):
"""Test that PermissionError is translated to 403 HTTP status."""
exc = PermissionError("Permission denied")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert result.detail == "Permission denied: Permission denied"
def test_translate_value_error(self):
"""Test that ValueError is translated to 400 HTTP status."""
exc = ValueError("Invalid input")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 400
assert result.detail == "Invalid value: Invalid input"
def test_translate_bad_request_error(self):
"""Test that BadRequestError is translated to 400 HTTP status."""
# Create a mock response for BadRequestError
mock_response = Mock()
mock_response.status_code = 400
mock_response.headers = {}
exc = BadRequestError("Bad request", response=mock_response, body="Bad request")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 400
assert result.detail == "Bad request"
def test_translate_authentication_required_error(self):
"""Test that AuthenticationRequiredError is translated to 401 HTTP status."""
exc = AuthenticationRequiredError("Authentication required")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 401
assert result.detail == "Authentication required: Authentication required"
def test_translate_timeout_error(self):
"""Test that TimeoutError is translated to 504 HTTP status."""
exc = TimeoutError("Operation timed out")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 504
assert result.detail == "Operation timed out: Operation timed out"
def test_translate_asyncio_timeout_error(self):
"""Test that asyncio.TimeoutError is translated to 504 HTTP status."""
exc = TimeoutError()
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 504
assert result.detail == "Operation timed out: "
def test_translate_not_implemented_error(self):
"""Test that NotImplementedError is translated to 501 HTTP status."""
exc = NotImplementedError("Not implemented")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 501
assert result.detail == "Not implemented: Not implemented"
def test_translate_validation_error(self):
"""Test that ValidationError is translated to 400 HTTP status with proper format."""
# Create a mock validation error using proper Pydantic error format
exc = ValidationError.from_exception_data(
"TestModel",
[
{
"loc": ("field", "nested"),
"msg": "field required",
"type": "missing",
}
],
)
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 400
assert "errors" in result.detail
assert len(result.detail["errors"]) == 1
assert result.detail["errors"][0]["loc"] == ["field", "nested"]
assert result.detail["errors"][0]["msg"] == "Field required"
assert result.detail["errors"][0]["type"] == "missing"
def test_translate_generic_exception(self):
"""Test that generic exceptions are translated to 500 HTTP status."""
exc = Exception("Unexpected error")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 500
assert result.detail == "Internal server error: An unexpected error occurred."
def test_translate_runtime_error(self):
"""Test that RuntimeError is translated to 500 HTTP status."""
exc = RuntimeError("Runtime error")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 500
assert result.detail == "Internal server error: An unexpected error occurred."
def test_multiple_access_denied_scenarios(self):
"""Test various scenarios that should result in 403 status codes."""
# Test AccessDeniedError (uses enhanced message)
exc1 = AccessDeniedError()
result1 = translate_exception(exc1)
assert isinstance(result1, HTTPException)
assert result1.status_code == 403
assert result1.detail == "Permission denied: Insufficient permissions"
# Test PermissionError (uses generic message)
exc2 = PermissionError("No permission")
result2 = translate_exception(exc2)
assert isinstance(result2, HTTPException)
assert result2.status_code == 403
assert result2.detail == "Permission denied: No permission"
exc3 = PermissionError("Access denied")
result3 = translate_exception(exc3)
assert isinstance(result3, HTTPException)
assert result3.status_code == 403
assert result3.detail == "Permission denied: Access denied"