Merge branch 'main' into vllm_health_check

commit c18b585d32
143 changed files with 9210 additions and 5347 deletions
@@ -121,7 +121,7 @@ class ToolGroupsImpl(Impl):
 @pytest.mark.asyncio
 async def test_models_routing_table(cached_disk_dist_registry):
-    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
+    table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
     await table.initialize()

     # Register multiple models and verify listing

@@ -163,7 +163,7 @@ async def test_models_routing_table(cached_disk_dist_registry):
 @pytest.mark.asyncio
 async def test_shields_routing_table(cached_disk_dist_registry):
-    table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
+    table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
     await table.initialize()

     # Register multiple shields and verify listing

@@ -179,14 +179,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
 @pytest.mark.asyncio
 async def test_vectordbs_routing_table(cached_disk_dist_registry):
-    table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
+    table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
     await table.initialize()

-    m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
+    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
     await m_table.initialize()
     await m_table.register_model(
         model_id="test-model",
-        provider_id="test_providere",
+        provider_id="test_provider",
         metadata={"embedding_dimension": 128},
         model_type=ModelType.embedding,
     )

@@ -209,7 +209,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):

 async def test_datasets_routing_table(cached_disk_dist_registry):
-    table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
+    table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
     await table.initialize()

     # Register multiple datasets and verify listing

@@ -235,7 +235,7 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
 @pytest.mark.asyncio
 async def test_scoring_functions_routing_table(cached_disk_dist_registry):
-    table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
+    table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
     await table.initialize()

     # Register multiple scoring functions and verify listing

@@ -261,7 +261,7 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
 @pytest.mark.asyncio
 async def test_benchmarks_routing_table(cached_disk_dist_registry):
-    table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
+    table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
     await table.initialize()

     # Register multiple benchmarks and verify listing

@@ -279,7 +279,7 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
 @pytest.mark.asyncio
 async def test_tool_groups_routing_table(cached_disk_dist_registry):
-    table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
+    table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
     await table.initialize()

     # Register multiple tool groups and verify listing
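Every hunk above makes the same mechanical change: each routing-table constructor gains a third argument, an access-control policy, passed as an empty `{}` in these tests. Written out with the keyword names that appear in the ModelsRoutingTable fixture later in this commit, a minimal sketch of the new call shape:

    # Sketch only: keyword names taken from the test_setup fixture further down
    # in this commit; an empty policy means no access rules are enforced.
    table = ModelsRoutingTable(
        impls_by_provider_id={"test_provider": InferenceImpl()},
        dist_registry=cached_disk_dist_registry,
        policy={},
    )
    await table.initialize()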

tests/unit/files/__init__.py (new file, +5)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

tests/unit/files/test_files.py (new file, +334)
@@ -0,0 +1,334 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.common.responses import Order
+from llama_stack.apis.files import OpenAIFilePurpose
+from llama_stack.providers.inline.files.localfs import (
+    LocalfsFilesImpl,
+    LocalfsFilesImplConfig,
+)
+from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
+
+
+class MockUploadFile:
+    """Mock UploadFile for testing file uploads."""
+
+    def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
+        self.content = content
+        self.filename = filename
+        self.content_type = content_type
+
+    async def read(self):
+        return self.content
+
+
+@pytest_asyncio.fixture
+async def files_provider(tmp_path):
+    """Create a files provider with temporary storage for testing."""
+    storage_dir = tmp_path / "files"
+    db_path = tmp_path / "files_metadata.db"
+
+    config = LocalfsFilesImplConfig(
+        storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix())
+    )
+
+    provider = LocalfsFilesImpl(config)
+    await provider.initialize()
+    yield provider
+
+
+@pytest.fixture
+def sample_text_file():
+    """Sample text file for testing."""
+    content = b"Hello, this is a test file for the OpenAI Files API!"
+    return MockUploadFile(content, "test.txt", "text/plain")
+
+
+@pytest.fixture
+def sample_json_file():
+    """Sample JSON file for testing."""
+    content = b'{"message": "Hello, World!", "type": "test"}'
+    return MockUploadFile(content, "data.json", "application/json")
+
+
+@pytest.fixture
+def large_file():
+    """Large file for testing file size handling."""
+    content = b"x" * 1024 * 1024  # 1MB file
+    return MockUploadFile(content, "large_file.bin", "application/octet-stream")
+
+
+class TestOpenAIFilesAPI:
+    """Test suite for OpenAI Files API endpoints."""
+
+    @pytest.mark.asyncio
+    async def test_upload_file_success(self, files_provider, sample_text_file):
+        """Test successful file upload."""
+        # Upload file
+        result = await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
+
+        # Verify response
+        assert result.id.startswith("file-")
+        assert result.filename == "test.txt"
+        assert result.purpose == OpenAIFilePurpose.ASSISTANTS
+        assert result.bytes == len(sample_text_file.content)
+        assert result.created_at > 0
+        assert result.expires_at > result.created_at
+
+    @pytest.mark.asyncio
+    async def test_upload_different_purposes(self, files_provider, sample_text_file):
+        """Test uploading files with different purposes."""
+        purposes = list(OpenAIFilePurpose)
+
+        uploaded_files = []
+        for purpose in purposes:
+            result = await files_provider.openai_upload_file(file=sample_text_file, purpose=purpose)
+            uploaded_files.append(result)
+            assert result.purpose == purpose
+
+    @pytest.mark.asyncio
+    async def test_upload_different_file_types(self, files_provider, sample_text_file, sample_json_file, large_file):
+        """Test uploading different types and sizes of files."""
+        files_to_test = [
+            (sample_text_file, "test.txt"),
+            (sample_json_file, "data.json"),
+            (large_file, "large_file.bin"),
+        ]
+
+        for file_obj, expected_filename in files_to_test:
+            result = await files_provider.openai_upload_file(file=file_obj, purpose=OpenAIFilePurpose.ASSISTANTS)
+            assert result.filename == expected_filename
+            assert result.bytes == len(file_obj.content)
+
+    @pytest.mark.asyncio
+    async def test_list_files_empty(self, files_provider):
+        """Test listing files when no files exist."""
+        result = await files_provider.openai_list_files()
+
+        assert result.data == []
+        assert result.has_more is False
+        assert result.first_id == ""
+        assert result.last_id == ""
+
+    @pytest.mark.asyncio
+    async def test_list_files_with_content(self, files_provider, sample_text_file, sample_json_file):
+        """Test listing files when files exist."""
+        # Upload multiple files
+        file1 = await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
+        file2 = await files_provider.openai_upload_file(file=sample_json_file, purpose=OpenAIFilePurpose.ASSISTANTS)
+
+        # List files
+        result = await files_provider.openai_list_files()
+
+        assert len(result.data) == 2
+        file_ids = [f.id for f in result.data]
+        assert file1.id in file_ids
+        assert file2.id in file_ids
+
+    @pytest.mark.asyncio
+    async def test_list_files_with_purpose_filter(self, files_provider, sample_text_file):
+        """Test listing files with purpose filtering."""
+        # Upload file with specific purpose
+        uploaded_file = await files_provider.openai_upload_file(
+            file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
+        )
+
+        # List files with matching purpose
+        result = await files_provider.openai_list_files(purpose=OpenAIFilePurpose.ASSISTANTS)
+        assert len(result.data) == 1
+        assert result.data[0].id == uploaded_file.id
+        assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
+
+    @pytest.mark.asyncio
+    async def test_list_files_with_limit(self, files_provider, sample_text_file):
+        """Test listing files with limit parameter."""
+        # Upload multiple files
+        for _ in range(5):
+            await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
+
+        # List with limit
+        result = await files_provider.openai_list_files(limit=3)
+        assert len(result.data) == 3
+
+    @pytest.mark.asyncio
+    async def test_list_files_with_order(self, files_provider, sample_text_file):
+        """Test listing files with different order."""
+        # Upload multiple files
+        files = []
+        for _ in range(3):
+            file = await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
+            files.append(file)
+
+        # Test descending order (default)
+        result_desc = await files_provider.openai_list_files(order=Order.desc)
+        assert len(result_desc.data) == 3
+        # Most recent should be first
+        assert result_desc.data[0].created_at >= result_desc.data[1].created_at >= result_desc.data[2].created_at
+
+        # Test ascending order
+        result_asc = await files_provider.openai_list_files(order=Order.asc)
+        assert len(result_asc.data) == 3
+        # Oldest should be first
+        assert result_asc.data[0].created_at <= result_asc.data[1].created_at <= result_asc.data[2].created_at
+
+    @pytest.mark.asyncio
+    async def test_retrieve_file_success(self, files_provider, sample_text_file):
+        """Test successful file retrieval."""
+        # Upload file
+        uploaded_file = await files_provider.openai_upload_file(
+            file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
+        )
+
+        # Retrieve file
+        retrieved_file = await files_provider.openai_retrieve_file(uploaded_file.id)
+
+        # Verify response
+        assert retrieved_file.id == uploaded_file.id
+        assert retrieved_file.filename == uploaded_file.filename
+        assert retrieved_file.purpose == uploaded_file.purpose
+        assert retrieved_file.bytes == uploaded_file.bytes
+        assert retrieved_file.created_at == uploaded_file.created_at
+        assert retrieved_file.expires_at == uploaded_file.expires_at
+
+    @pytest.mark.asyncio
+    async def test_retrieve_file_not_found(self, files_provider):
+        """Test retrieving a non-existent file."""
+        with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
+            await files_provider.openai_retrieve_file("file-nonexistent")
+
+    @pytest.mark.asyncio
+    async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
+        """Test successful file content retrieval."""
+        # Upload file
+        uploaded_file = await files_provider.openai_upload_file(
+            file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
+        )
+
+        # Retrieve file content
+        content = await files_provider.openai_retrieve_file_content(uploaded_file.id)
+
+        # Verify content
+        assert content.body == sample_text_file.content
+
+    @pytest.mark.asyncio
+    async def test_retrieve_file_content_not_found(self, files_provider):
+        """Test retrieving content of a non-existent file."""
+        with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
+            await files_provider.openai_retrieve_file_content("file-nonexistent")
+
+    @pytest.mark.asyncio
+    async def test_delete_file_success(self, files_provider, sample_text_file):
+        """Test successful file deletion."""
+        # Upload file
+        uploaded_file = await files_provider.openai_upload_file(
+            file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
+        )
+
+        # Verify file exists
+        await files_provider.openai_retrieve_file(uploaded_file.id)
+
+        # Delete file
+        delete_response = await files_provider.openai_delete_file(uploaded_file.id)
+
+        # Verify delete response
+        assert delete_response.id == uploaded_file.id
+        assert delete_response.deleted is True
+
+        # Verify file no longer exists
+        with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
+            await files_provider.openai_retrieve_file(uploaded_file.id)
+
+    @pytest.mark.asyncio
+    async def test_delete_file_not_found(self, files_provider):
+        """Test deleting a non-existent file."""
+        with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
+            await files_provider.openai_delete_file("file-nonexistent")
+
+    @pytest.mark.asyncio
+    async def test_file_persistence_across_operations(self, files_provider, sample_text_file):
+        """Test that files persist correctly across multiple operations."""
+        # Upload file
+        uploaded_file = await files_provider.openai_upload_file(
+            file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
+        )
+
+        # Verify it appears in listing
+        files_list = await files_provider.openai_list_files()
+        assert len(files_list.data) == 1
+        assert files_list.data[0].id == uploaded_file.id
+
+        # Retrieve file info
+        retrieved_file = await files_provider.openai_retrieve_file(uploaded_file.id)
+        assert retrieved_file.id == uploaded_file.id
+
+        # Retrieve file content
+        content = await files_provider.openai_retrieve_file_content(uploaded_file.id)
+        assert content.body == sample_text_file.content
+
+        # Delete file
+        await files_provider.openai_delete_file(uploaded_file.id)
+
+        # Verify it's gone from listing
+        files_list = await files_provider.openai_list_files()
+        assert len(files_list.data) == 0
+
+    @pytest.mark.asyncio
+    async def test_multiple_files_operations(self, files_provider, sample_text_file, sample_json_file):
+        """Test operations with multiple files."""
+        # Upload multiple files
+        file1 = await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
+        file2 = await files_provider.openai_upload_file(file=sample_json_file, purpose=OpenAIFilePurpose.ASSISTANTS)
+
+        # Verify both exist
+        files_list = await files_provider.openai_list_files()
+        assert len(files_list.data) == 2
+
+        # Delete one file
+        await files_provider.openai_delete_file(file1.id)
+
+        # Verify only one remains
+        files_list = await files_provider.openai_list_files()
+        assert len(files_list.data) == 1
+        assert files_list.data[0].id == file2.id
+
+        # Verify the remaining file is still accessible
+        content = await files_provider.openai_retrieve_file_content(file2.id)
+        assert content.body == sample_json_file.content
+
+    @pytest.mark.asyncio
+    async def test_file_id_uniqueness(self, files_provider, sample_text_file):
+        """Test that each uploaded file gets a unique ID."""
+        file_ids = set()
+
+        # Upload same file multiple times
+        for _ in range(10):
+            uploaded_file = await files_provider.openai_upload_file(
+                file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
+            )
+            assert uploaded_file.id not in file_ids, f"Duplicate file ID: {uploaded_file.id}"
+            file_ids.add(uploaded_file.id)
+            assert uploaded_file.id.startswith("file-")
+
+    @pytest.mark.asyncio
+    async def test_file_no_filename_handling(self, files_provider):
+        """Test handling files with no filename."""
+        file_without_name = MockUploadFile(b"content", None)  # No filename
+
+        uploaded_file = await files_provider.openai_upload_file(
+            file=file_without_name, purpose=OpenAIFilePurpose.ASSISTANTS
+        )
+
+        assert uploaded_file.filename == "uploaded_file"  # Default filename
+
+    @pytest.mark.asyncio
+    async def test_after_pagination_not_implemented(self, files_provider):
+        """Test that 'after' pagination raises NotImplementedError."""
+        with pytest.raises(NotImplementedError, match="After pagination not yet implemented"):
+            await files_provider.openai_list_files(after="file-some-id")
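The suite above doubles as usage documentation for the localfs files provider. A condensed round trip outside pytest, reusing the same config shape as the files_provider fixture (the temp paths are placeholders):

    import asyncio
    import tempfile
    from pathlib import Path

    async def main():
        tmp = Path(tempfile.mkdtemp())
        config = LocalfsFilesImplConfig(
            storage_dir=(tmp / "files").as_posix(),
            metadata_store=SqliteSqlStoreConfig(db_path=(tmp / "files_metadata.db").as_posix()),
        )
        provider = LocalfsFilesImpl(config)
        await provider.initialize()

        # Upload, read back, then delete, mirroring the persistence test above.
        uploaded = await provider.openai_upload_file(
            file=MockUploadFile(b"hello", "hello.txt"), purpose=OpenAIFilePurpose.ASSISTANTS
        )
        content = await provider.openai_retrieve_file_content(uploaded.id)
        assert content.body == b"hello"
        await provider.openai_delete_file(uploaded.id)

    asyncio.run(main())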

tests/unit/models/llama/test_tokenizer_utils.py (new file, +177)
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import time
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+from tiktoken.load import load_tiktoken_bpe
+
+from llama_stack.models.llama.tokenizer_utils import load_bpe_file
+
+
+@pytest.fixture
+def test_bpe_content():
+    """Sample BPE file content for testing."""
+    return """wA== 0
+wQ== 1
+9Q== 2
+9g== 3
+9w== 4
++A== 5
++Q== 6
++g== 7
++w== 8
+/A== 9
+/Q== 10
+/g== 11
+/w== 12
+AA== 13
+AQ== 14"""
+
+
+@pytest.fixture
+def test_bpe_file(tmp_path, test_bpe_content):
+    """Create a temporary BPE file for testing."""
+    bpe_file = tmp_path / "test_tokenizer.model"
+    bpe_file.write_text(test_bpe_content, encoding="utf-8")
+    return bpe_file
+
+
+@pytest.fixture
+def llama3_model_path():
+    """Path to Llama3 tokenizer model."""
+    return Path(__file__).parent / "../../../../llama_stack/models/llama/llama3/tokenizer.model"
+
+
+@pytest.fixture
+def llama4_model_path():
+    """Path to Llama4 tokenizer model."""
+    return Path(__file__).parent / "../../../../llama_stack/models/llama/llama4/tokenizer.model"
+
+
+def test_load_bpe_file_basic_functionality(test_bpe_file):
+    """Test that load_bpe_file correctly parses BPE files."""
+    result = load_bpe_file(test_bpe_file)
+
+    for key, value in result.items():
+        assert isinstance(key, bytes)
+        assert isinstance(value, int)
+
+    assert len(result) == 15
+
+    expected_first_token = base64.b64decode("wA==")
+    assert expected_first_token in result
+    assert result[expected_first_token] == 0
+
+
+def test_load_bpe_file_vs_tiktoken_with_real_model(llama3_model_path):
+    """Test that our implementation produces identical results to tiktoken on real model files."""
+    if not llama3_model_path.exists():
+        pytest.skip("Llama3 tokenizer model not found")
+
+    our_result = load_bpe_file(llama3_model_path)
+    tiktoken_result = load_tiktoken_bpe(llama3_model_path.as_posix())
+
+    # Compare results from our implementation and tiktoken
+    assert len(our_result) == len(tiktoken_result)
+    assert our_result == tiktoken_result
+
+    assert len(our_result) > 100000
+    ranks = list(our_result.values())
+    assert len(ranks) == len(set(ranks))
+
+
+def test_load_bpe_file_vs_tiktoken_with_llama4_model(llama4_model_path):
+    """Test that our implementation produces identical results to tiktoken on Llama4 model."""
+    if not llama4_model_path.exists():
+        pytest.skip("Llama4 tokenizer model not found")
+
+    our_result = load_bpe_file(llama4_model_path)
+    tiktoken_result = load_tiktoken_bpe(llama4_model_path.as_posix())
+
+    # Compare results from our implementation and tiktoken
+    assert len(our_result) == len(tiktoken_result)
+    assert our_result == tiktoken_result
+
+    assert len(our_result) > 100000
+    ranks = list(our_result.values())
+    assert len(ranks) == len(set(ranks))
+
+
+def test_load_bpe_file_malformed_lines(tmp_path):
+    """Test that load_bpe_file handles malformed lines gracefully."""
+    malformed_content = """wA== 0
+invalid_line_without_rank
+wQ== 1
+invalid_base64!!! 2
+9Q== 2"""
+
+    test_file = tmp_path / "malformed.model"
+    test_file.write_text(malformed_content, encoding="utf-8")
+
+    with patch("llama_stack.models.llama.tokenizer_utils.logger") as mock_logger:
+        result = load_bpe_file(test_file)
+
+        # Should have 3 valid entries (skipping malformed ones)
+        assert len(result) == 3
+
+        # Should have logged warnings for malformed lines
+        assert mock_logger.warning.called
+        assert mock_logger.warning.call_count > 0
+
+
+def test_load_bpe_file_nonexistent_file():
+    """Test that load_bpe_file raises appropriate error for nonexistent files."""
+    with pytest.raises(FileNotFoundError):
+        load_bpe_file("/nonexistent/path/to/file.model")
+
+
+def test_tokenizer_integration():
+    """Test that our load_bpe_file works correctly when used in actual tokenizers."""
+    try:
+        from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
+
+        tokenizer = Llama3Tokenizer.get_instance()
+
+        # Test basic functionality
+        test_text = "Hello, world! This is a test."
+        tokens = tokenizer.encode(test_text, bos=False, eos=False)
+        decoded = tokenizer.decode(tokens)
+
+        assert test_text == decoded
+        assert isinstance(tokens, list)
+        assert all(isinstance(token, int) for token in tokens)
+
+    except Exception as e:
+        pytest.skip(f"Llama3 tokenizer not available: {e}")
+
+
+def test_performance_comparison(llama3_model_path):
+    """Test that our implementation has reasonable performance compared to tiktoken."""
+    if not llama3_model_path.exists():
+        pytest.skip("Llama3 tokenizer model not found")
+
+    # Time our implementation
+    start_time = time.time()
+    our_result = load_bpe_file(llama3_model_path)
+    our_time = time.time() - start_time
+
+    # Time tiktoken implementation
+    start_time = time.time()
+    tiktoken_result = load_tiktoken_bpe(llama3_model_path.as_posix())
+    tiktoken_time = time.time() - start_time
+
+    # Verify results are identical
+    assert our_result == tiktoken_result
+
+    # Our implementation should be reasonably fast (within 10x of tiktoken)
+    # This is a loose bound since we're optimizing for correctness, not speed
+    assert our_time < tiktoken_time * 10, f"Our implementation took {our_time:.3f}s vs tiktoken's {tiktoken_time:.3f}s"
+
+    print(f"Performance comparison - Our: {our_time:.3f}s, Tiktoken: {tiktoken_time:.3f}s")
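For reference, the BPE vocabulary format exercised above is one token per line: a base64-encoded byte sequence followed by its integer rank. A minimal parser with the same skip-and-warn behavior the malformed-lines test expects; the real load_bpe_file lives in llama_stack.models.llama.tokenizer_utils and may differ in its details:

    import base64
    import logging

    logger = logging.getLogger(__name__)

    def load_bpe_file_sketch(path) -> dict[bytes, int]:
        ranks: dict[bytes, int] = {}
        with open(path, encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split()
                try:
                    token, rank = parts  # expect exactly "<base64> <rank>"
                    ranks[base64.b64decode(token, validate=True)] = int(rank)
                except ValueError:  # wrong field count, invalid base64, or non-integer rank
                    logger.warning(f"Skipping malformed BPE line: {line.strip()!r}")
        return ranks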
@@ -59,6 +59,7 @@ async def agents_impl(config, mock_apis):
         mock_apis["safety_api"],
         mock_apis["tool_runtime_api"],
         mock_apis["tool_groups_api"],
+        {},
     )
     await impl.initialize()
     yield impl
@@ -25,11 +25,17 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObjectWithInput,
     OpenAIResponseOutputMessageContentOutputText,
     OpenAIResponseOutputMessageWebSearchToolCall,
+    OpenAIResponseText,
+    OpenAIResponseTextFormat,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletionContentPartTextParam,
     OpenAIDeveloperMessageParam,
+    OpenAIJSONSchema,
+    OpenAIResponseFormatJSONObject,
+    OpenAIResponseFormatJSONSchema,
+    OpenAIResponseFormatText,
     OpenAIUserMessageParam,
 )
 from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime

@@ -96,6 +102,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
     mock_inference_api.openai_chat_completion.assert_called_once_with(
         model=model,
         messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
+        response_format=OpenAIResponseFormatText(),
         tools=None,
         stream=False,
         temperature=0.1,

@@ -224,16 +231,16 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
         ],
     )

-    # Verify
+    # Check that we got the content from our mocked tool execution result
+    chunks = [chunk async for chunk in result]
+    assert len(chunks) == 2  # Should have response.created and response.completed
+
+    # Verify inference API was called correctly (after iterating over result)
     first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
     assert first_call.kwargs["messages"][0].content == input_text
     assert first_call.kwargs["tools"] is not None
     assert first_call.kwargs["temperature"] == 0.1

-    # Check that we got the content from our mocked tool execution result
-    chunks = [chunk async for chunk in result]
-    assert len(chunks) == 2  # Should have response.created and response.completed
-
     # Check response.created event (should have empty output)
     assert chunks[0].type == "response.created"
     assert len(chunks[0].response.output) == 0

@@ -320,6 +327,7 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
         model="fake_model",
         output=[response_output_message],
         status="completed",
+        text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
     )
     mock_responses_store.get_response_object.return_value = previous_response

@@ -362,6 +370,7 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
         model="fake_model",
         output=[output_web_search, output_message],
         status="completed",
+        text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
     )
     mock_responses_store.get_response_object.return_value = response

@@ -483,6 +492,7 @@ async def test_create_openai_response_with_instructions_and_previous_response(
         model="fake_model",
         output=[response_output_message],
         status="completed",
+        text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
     )
     mock_responses_store.get_response_object.return_value = response

@@ -576,6 +586,7 @@ async def test_responses_store_list_input_items_logic():
         object="response",
         status="completed",
         output=[],
+        text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))),
         input=input_items,
     )

@@ -644,6 +655,7 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
         created_at=1234567890,
         model="meta-llama/Llama-3.1-8B-Instruct",
         status="completed",
+        text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[
             OpenAIResponseMessage(
                 id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")]

@@ -694,3 +706,61 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
     # Verify the response itself is correct
     assert result.model == model
     assert result.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "text_format, response_format",
+    [
+        (OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), OpenAIResponseFormatText()),
+        (
+            OpenAIResponseText(format=OpenAIResponseTextFormat(name="Test", schema={"foo": "bar"}, type="json_schema")),
+            OpenAIResponseFormatJSONSchema(json_schema=OpenAIJSONSchema(name="Test", schema={"foo": "bar"})),
+        ),
+        (OpenAIResponseText(format=OpenAIResponseTextFormat(type="json_object")), OpenAIResponseFormatJSONObject()),
+        # ensure text param with no format specified defaults to text
+        (OpenAIResponseText(format=None), OpenAIResponseFormatText()),
+        # ensure text param of None defaults to text
+        (None, OpenAIResponseFormatText()),
+    ],
+)
+async def test_create_openai_response_with_text_format(
+    openai_responses_impl, mock_inference_api, text_format, response_format
+):
+    """Test creating Responses with text formats."""
+    # Setup
+    input_text = "How hot it is in San Francisco today?"
+    model = "meta-llama/Llama-3.1-8B-Instruct"
+
+    # Load the chat completion fixture
+    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
+    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+
+    # Execute
+    _result = await openai_responses_impl.create_openai_response(
+        input=input_text,
+        model=model,
+        text=text_format,
+    )
+
+    # Verify
+    first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
+    assert first_call.kwargs["messages"][0].content == input_text
+    assert first_call.kwargs["response_format"] is not None
+    assert first_call.kwargs["response_format"] == response_format
+
+
+@pytest.mark.asyncio
+async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
+    """Test creating an OpenAI response with an invalid text format."""
+    # Setup
+    input_text = "How hot it is in San Francisco today?"
+    model = "meta-llama/Llama-3.1-8B-Instruct"
+
+    # Execute
+    with pytest.raises(ValueError):
+        _result = await openai_responses_impl.create_openai_response(
+            input=input_text,
+            model=model,
+            text=OpenAIResponseText(format={"type": "invalid"}),
+        )
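The parametrized cases above pin down a small conversion table from the Responses `text` parameter to a chat-completions `response_format`. As a readable summary (the helper name is hypothetical; the real conversion lives inside the responses implementation):

    def text_to_response_format(text):
        # None, or a text param with no format set, both default to plain text
        if text is None or text.format is None:
            return OpenAIResponseFormatText()
        fmt = text.format
        if fmt["type"] == "text":
            return OpenAIResponseFormatText()
        if fmt["type"] == "json_object":
            return OpenAIResponseFormatJSONObject()
        if fmt["type"] == "json_schema":
            return OpenAIResponseFormatJSONSchema(
                json_schema=OpenAIJSONSchema(name=fmt["name"], schema=fmt["schema"])
            )
        raise ValueError(f"Unsupported text format: {fmt}")  # the invalid-format case above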
@@ -12,24 +12,24 @@ import pytest
 from llama_stack.apis.agents import Turn
 from llama_stack.apis.inference import CompletionMessage, StopReason
-from llama_stack.distribution.datatypes import AccessAttributes
+from llama_stack.distribution.datatypes import User
 from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo


 @pytest.fixture
 async def test_setup(sqlite_kvstore):
-    agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore)
+    agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
     yield agent_persistence


 @pytest.mark.asyncio
-@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
-async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
+@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
+async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
     agent_persistence = test_setup

     # Set creator's attributes for the session
     creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
-    mock_get_auth_attributes.return_value = creator_attributes
+    mock_get_authenticated_user.return_value = User("test_user", creator_attributes)

     # Create a session
     session_id = await agent_persistence.create_session("Test Session")

@@ -37,14 +37,15 @@ async def test_session_creation_with_access_attributes(mock_get_auth_attributes,
     # Get the session and verify access attributes were set
     session_info = await agent_persistence.get_session_info(session_id)
     assert session_info is not None
-    assert session_info.access_attributes is not None
-    assert session_info.access_attributes.roles == ["researcher"]
-    assert session_info.access_attributes.teams == ["ai-team"]
+    assert session_info.owner is not None
+    assert session_info.owner.attributes is not None
+    assert session_info.owner.attributes["roles"] == ["researcher"]
+    assert session_info.owner.attributes["teams"] == ["ai-team"]


 @pytest.mark.asyncio
-@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
-async def test_session_access_control(mock_get_auth_attributes, test_setup):
+@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
+async def test_session_access_control(mock_get_authenticated_user, test_setup):
     agent_persistence = test_setup

     # Create a session with specific access attributes

@@ -53,8 +54,9 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
         session_id=session_id,
         session_name="Restricted Session",
         started_at=datetime.now(),
-        access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
+        owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
         turns=[],
+        identifier="Restricted Session",
     )

     await agent_persistence.kvstore.set(

@@ -63,20 +65,22 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
     )

     # User with matching attributes can access
-    mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
+    mock_get_authenticated_user.return_value = User(
+        "testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
+    )
     retrieved_session = await agent_persistence.get_session_info(session_id)
     assert retrieved_session is not None
     assert retrieved_session.session_id == session_id

     # User without matching attributes cannot access
-    mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
+    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
     retrieved_session = await agent_persistence.get_session_info(session_id)
     assert retrieved_session is None


 @pytest.mark.asyncio
-@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
-async def test_turn_access_control(mock_get_auth_attributes, test_setup):
+@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
+async def test_turn_access_control(mock_get_authenticated_user, test_setup):
     agent_persistence = test_setup

     # Create a session with restricted access

@@ -85,8 +89,9 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
         session_id=session_id,
         session_name="Restricted Session",
         started_at=datetime.now(),
-        access_attributes=AccessAttributes(roles=["admin"]),
+        owner=User("someone", {"roles": ["admin"]}),
         turns=[],
+        identifier="Restricted Session",
     )

     await agent_persistence.kvstore.set(

@@ -109,7 +114,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
     )

     # Admin can add turn
-    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
+    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
     await agent_persistence.add_turn_to_session(session_id, turn)

     # Admin can get turn

@@ -118,7 +123,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
     assert retrieved_turn.turn_id == turn_id

     # Regular user cannot get turn
-    mock_get_auth_attributes.return_value = {"roles": ["user"]}
+    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
     with pytest.raises(ValueError):
         await agent_persistence.get_session_turn(session_id, turn_id)

@@ -128,8 +133,8 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):


 @pytest.mark.asyncio
-@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
-async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
+@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
+async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
     agent_persistence = test_setup

     # Create a session with restricted access

@@ -138,8 +143,9 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
         session_id=session_id,
         session_name="Restricted Session",
         started_at=datetime.now(),
-        access_attributes=AccessAttributes(roles=["admin"]),
+        owner=User("someone", {"roles": ["admin"]}),
         turns=[],
+        identifier="Restricted Session",
     )

     await agent_persistence.kvstore.set(

@@ -150,7 +156,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
     turn_id = str(uuid.uuid4())

     # Admin user can set inference iterations
-    mock_get_auth_attributes.return_value = {"roles": ["admin"]}
+    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
     await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)

     # Admin user can get inference iterations

@@ -158,7 +164,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
     assert infer_iters == 5

     # Regular user cannot get inference iterations
-    mock_get_auth_attributes.return_value = {"roles": ["user"]}
+    mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
     infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
     assert infer_iters is None
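The repeated rename in this file reflects one shape change: the patched helper used to return a bare attribute dict and now returns a User that carries a principal plus those attributes. Inferred from the call sites in this diff rather than from the User class itself:

    # before: get_auth_attributes() -> dict | None
    # after:  get_authenticated_user() -> User | None
    user = User("testuser", {"roles": ["admin"], "teams": ["security-team"]})
    assert user.principal == "testuser"                   # who is acting
    assert user.attributes["teams"] == ["security-team"]  # what they can match on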
@@ -70,9 +70,12 @@ class MockInferenceAdapterWithSleep:
             # ruff: noqa: N802
             def do_POST(self):
                 time.sleep(sleep_time)
+                response_body = json.dumps(response).encode("utf-8")
                 self.send_response(code=200)
                 self.send_header("Content-Type", "application/json")
+                self.send_header("Content-Length", len(response_body))
                 self.end_headers()
-                self.wfile.write(json.dumps(response).encode("utf-8"))
+                self.wfile.write(response_body)

         self.request_handler = DelayedRequestHandler
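The hunk above fixes a small mock-server gap: the handler never advertised a Content-Length, so clients had to rely on the connection closing to find the end of the body. The fix encodes the body once, sends its exact byte count, and writes those same bytes. The corrected pattern as a standalone http.server handler:

    import json
    from http.server import BaseHTTPRequestHandler

    class JSONHandler(BaseHTTPRequestHandler):
        def do_POST(self):
            response = {"ok": True}  # placeholder payload
            body = json.dumps(response).encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))  # length of the bytes actually written
            self.end_headers()
            self.wfile.write(body)  # same bytes object, so header and body cannot drift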
@@ -8,19 +8,18 @@
 import pytest

 from llama_stack.apis.models import ModelType
-from llama_stack.distribution.datatypes import ModelWithACL
-from llama_stack.distribution.server.auth_providers import AccessAttributes
+from llama_stack.distribution.datatypes import ModelWithOwner, User
 from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry


 @pytest.mark.asyncio
 async def test_registry_cache_with_acl(cached_disk_dist_registry):
-    model = ModelWithACL(
+    model = ModelWithOwner(
         identifier="model-acl",
         provider_id="test-provider",
         provider_resource_id="model-acl-resource",
         model_type=ModelType.llm,
-        access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
+        owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}),
     )

     success = await cached_disk_dist_registry.register(model)

@@ -29,22 +28,14 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
     cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
     assert cached_model is not None
     assert cached_model.identifier == "model-acl"
-    assert cached_model.access_attributes.roles == ["admin"]
-    assert cached_model.access_attributes.teams == ["ai-team"]
+    assert cached_model.owner.principal == "testuser"
+    assert cached_model.owner.attributes["roles"] == ["admin"]
+    assert cached_model.owner.attributes["teams"] == ["ai-team"]

     fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
     assert fetched_model is not None
     assert fetched_model.identifier == "model-acl"
-    assert fetched_model.access_attributes.roles == ["admin"]
-
-    model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
-    await cached_disk_dist_registry.update(model)
-
-    updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
-    assert updated_cached is not None
-    assert updated_cached.access_attributes.roles == ["admin", "user"]
-    assert updated_cached.access_attributes.projects == ["project-x"]
-    assert updated_cached.access_attributes.teams is None
+    assert fetched_model.owner.attributes["roles"] == ["admin"]

     new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
     await new_registry.initialize()

@@ -52,35 +43,32 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
     new_model = await new_registry.get("model", "model-acl")
     assert new_model is not None
     assert new_model.identifier == "model-acl"
-    assert new_model.access_attributes.roles == ["admin", "user"]
-    assert new_model.access_attributes.projects == ["project-x"]
-    assert new_model.access_attributes.teams is None
+    assert new_model.owner.principal == "testuser"
+    assert new_model.owner.attributes["roles"] == ["admin"]
+    assert new_model.owner.attributes["teams"] == ["ai-team"]


 @pytest.mark.asyncio
 async def test_registry_empty_acl(cached_disk_dist_registry):
-    model = ModelWithACL(
+    model = ModelWithOwner(
         identifier="model-empty-acl",
         provider_id="test-provider",
         provider_resource_id="model-resource",
         model_type=ModelType.llm,
-        access_attributes=AccessAttributes(),
+        owner=User("testuser", None),
     )

     await cached_disk_dist_registry.register(model)

     cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
     assert cached_model is not None
-    assert cached_model.access_attributes is not None
-    assert cached_model.access_attributes.roles is None
-    assert cached_model.access_attributes.teams is None
-    assert cached_model.access_attributes.projects is None
-    assert cached_model.access_attributes.namespaces is None
+    assert cached_model.owner is not None
+    assert cached_model.owner.attributes is None

     all_models = await cached_disk_dist_registry.get_all()
     assert len(all_models) == 1

-    model = ModelWithACL(
+    model = ModelWithOwner(
         identifier="model-no-acl",
         provider_id="test-provider",
         provider_resource_id="model-resource-2",

@@ -91,7 +79,7 @@ async def test_registry_empty_acl(cached_disk_dist_registry):

     cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
     assert cached_model is not None
-    assert cached_model.access_attributes is None
+    assert cached_model.owner is None

     all_models = await cached_disk_dist_registry.get_all()
     assert len(all_models) == 2

@@ -99,19 +87,19 @@ async def test_registry_empty_acl(cached_disk_dist_registry):

 @pytest.mark.asyncio
 async def test_registry_serialization(cached_disk_dist_registry):
-    attributes = AccessAttributes(
-        roles=["admin", "researcher"],
-        teams=["ai-team", "ml-team"],
-        projects=["project-a", "project-b"],
-        namespaces=["prod", "staging"],
-    )
+    attributes = {
+        "roles": ["admin", "researcher"],
+        "teams": ["ai-team", "ml-team"],
+        "projects": ["project-a", "project-b"],
+        "namespaces": ["prod", "staging"],
+    }

-    model = ModelWithACL(
+    model = ModelWithOwner(
         identifier="model-serialize",
         provider_id="test-provider",
         provider_resource_id="model-resource",
         model_type=ModelType.llm,
-        access_attributes=attributes,
+        owner=User("bob", attributes),
     )

     await cached_disk_dist_registry.register(model)

@@ -122,7 +110,7 @@ async def test_registry_serialization(cached_disk_dist_registry):
     loaded_model = await new_registry.get("model", "model-serialize")
     assert loaded_model is not None

-    assert loaded_model.access_attributes.roles == ["admin", "researcher"]
-    assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
-    assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
-    assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
+    assert loaded_model.owner.attributes["roles"] == ["admin", "researcher"]
+    assert loaded_model.owner.attributes["teams"] == ["ai-team", "ml-team"]
+    assert loaded_model.owner.attributes["projects"] == ["project-a", "project-b"]
+    assert loaded_model.owner.attributes["namespaces"] == ["prod", "staging"]
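On the resource side, the same migration replaces per-resource AccessAttributes with an owner. Field names exactly as used in the tests above; the class definitions themselves live in llama_stack.distribution.datatypes:

    # A restricted model records who registered it, plus that user's attributes.
    model = ModelWithOwner(
        identifier="model-acl",
        provider_id="test-provider",
        provider_resource_id="model-acl-resource",
        model_type=ModelType.llm,
        owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}),
    )

    # The old "no ACL" case now simply omits the owner field.
    public_model = ModelWithOwner(
        identifier="model-no-acl",
        provider_id="test-provider",
        provider_resource_id="model-resource-2",
        model_type=ModelType.llm,
    )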
@ -7,10 +7,13 @@
|
|||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User
|
||||
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
||||
|
||||
|
||||
|
|
@ -32,39 +35,40 @@ async def test_setup(cached_disk_dist_registry):
|
|||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test_provider": mock_inference},
|
||||
dist_registry=cached_disk_dist_registry,
|
||||
policy={},
|
||||
)
|
||||
yield cached_disk_dist_registry, routing_table
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
model_public = ModelWithOwner(
|
||||
identifier="model-public",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-public",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
model_admin_only = ModelWithACL(
|
||||
model_admin_only = ModelWithOwner(
|
||||
identifier="model-admin",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-admin",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
owner=User("testuser", {"roles": ["admin"]}),
|
||||
)
|
||||
model_data_scientist = ModelWithACL(
|
||||
model_data_scientist = ModelWithOwner(
|
||||
identifier="model-data-scientist",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-data-scientist",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
|
||||
owner=User("testuser", {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]}),
|
||||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_admin_only)
|
||||
await registry.register(model_data_scientist)
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]})
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 2
|
||||
|
||||
|
|
@ -75,7 +79,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-data-scientist")
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]})
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 1
|
||||
assert all_models.data[0].identifier == "model-public"
|
||||
|
|
@ -86,7 +90,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-data-scientist")
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]})
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 2
|
||||
model_ids = [m.identifier for m in all_models.data]
|
||||
|
|
@ -102,50 +106,62 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
model_public = ModelWithOwner(
|
||||
identifier="model-updates",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-updates",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
await registry.register(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": ["user"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
model_public.access_attributes = AccessAttributes(roles=["admin"])
|
||||
model_public.owner = User("testuser", {"roles": ["admin"]})
|
||||
await registry.update(model_public)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["user"],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": ["user"],
|
||||
},
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-updates")
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["admin"],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": ["admin"],
|
||||
},
|
||||
)
|
||||
model = await routing_table.get_model("model-updates")
|
||||
assert model.identifier == "model-updates"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model = ModelWithACL(
|
||||
model = ModelWithOwner(
|
||||
identifier="model-empty-attrs",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-empty-attrs",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(),
|
||||
owner=User("testuser", {}),
|
||||
)
|
||||
await registry.register(model)
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": [],
|
||||
}
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"test-user",
|
||||
{
|
||||
"roles": [],
|
||||
},
|
||||
)
|
||||
result = await routing_table.get_model("model-empty-attrs")
|
||||
assert result.identifier == "model-empty-attrs"
|
||||
all_models = await routing_table.list_models()
|
||||
|
|
@ -154,25 +170,25 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
|
||||
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
|
||||
registry, routing_table = test_setup
|
||||
model_public = ModelWithACL(
|
||||
model_public = ModelWithOwner(
|
||||
identifier="model-public-2",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-public-2",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
model_restricted = ModelWithACL(
|
||||
model_restricted = ModelWithOwner(
|
||||
identifier="model-restricted",
|
||||
provider_id="test_provider",
|
||||
provider_resource_id="model-restricted",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(roles=["admin"]),
|
||||
owner=User("testuser", {"roles": ["admin"]}),
|
||||
)
|
||||
await registry.register(model_public)
|
||||
await registry.register(model_restricted)
|
||||
mock_get_auth_attributes.return_value = None
|
||||
mock_get_authenticated_user.return_value = User("test-user", None)
|
||||
model = await routing_table.get_model("model-public-2")
|
||||
assert model.identifier == "model-public-2"
|
||||
|
||||
|
|
@ -185,17 +201,17 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
    """Test that newly created resources inherit access attributes from their creator."""
    registry, routing_table = test_setup

    # Set creator's attributes
    creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
    mock_get_auth_attributes.return_value = creator_attributes
    mock_get_authenticated_user.return_value = User("test-user", creator_attributes)

    # Create model without explicit access attributes
    model = ModelWithACL(
    model = ModelWithOwner(
        identifier="auto-access-model",
        provider_id="test_provider",
        provider_resource_id="auto-access-model",

@@ -205,21 +221,346 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):

    # Verify the model got creator's attributes
    registered_model = await routing_table.get_model("auto-access-model")
    assert registered_model.access_attributes is not None
    assert registered_model.access_attributes.roles == ["data-scientist"]
    assert registered_model.access_attributes.teams == ["ml-team"]
    assert registered_model.access_attributes.projects == ["llama-3"]
    assert registered_model.owner is not None
    assert registered_model.owner.attributes is not None
    assert registered_model.owner.attributes["roles"] == ["data-scientist"]
    assert registered_model.owner.attributes["teams"] == ["ml-team"]
    assert registered_model.owner.attributes["projects"] == ["llama-3"]

    # Verify another user without matching attributes can't access it
    mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
    mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
    with pytest.raises(ValueError):
        await routing_table.get_model("auto-access-model")

    # But a user with matching attributes can
    mock_get_auth_attributes.return_value = {
        "roles": ["data-scientist", "engineer"],
        "teams": ["ml-team", "platform-team"],
        "projects": ["llama-3"],
    }
    mock_get_authenticated_user.return_value = User(
        "test-user",
        {
            "roles": ["data-scientist", "engineer"],
            "teams": ["ml-team", "platform-team"],
            "projects": ["llama-3"],
        },
    )
    model = await routing_table.get_model("auto-access-model")
    assert model.identifier == "auto-access-model"

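This test pins down the ownership model the patch moves to: a resource's owner is simply the User who registered it, and later access checks compare the caller against that owner. A minimal sketch of the shapes involved, with constructor signatures inferred from the calls in this file and a hypothetical identifier:

# Sketch only; shapes inferred from the test calls above, not from the class definitions.
owner = User("testuser", {"roles": ["admin"]})  # principal plus attribute dict
model = ModelWithOwner(
    identifier="example-model",  # hypothetical
    provider_id="test_provider",
    model_type=ModelType.llm,
    owner=owner,  # explicit here; stamped from the authenticated user when omitted
)
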
@pytest.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry):
    mock_inference = Mock()
    mock_inference.__provider_spec__ = MagicMock()
    mock_inference.__provider_spec__.api = Api.inference
    mock_inference.register_model = AsyncMock(side_effect=_return_model)
    mock_inference.unregister_model = AsyncMock(side_effect=_return_model)

    config = """
    - permit:
        principal: user-1
        actions: [create, read, delete]
      description: user-1 has full access to all models
    - permit:
        principal: user-2
        actions: [read]
        resource: model::model-1
      description: user-2 has read access to model-1 only
    - permit:
        principal: user-3
        actions: [read]
        resource: model::model-2
      description: user-3 has read access to model-2 only
    - forbid:
        actions: [create, read, delete]
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    routing_table = ModelsRoutingTable(
        impls_by_provider_id={"test_provider": mock_inference},
        dist_registry=cached_disk_dist_registry,
        policy=policy,
    )
    yield routing_table

@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
    routing_table = test_setup_with_access_policy
    mock_get_authenticated_user.return_value = User(
        "user-1",
        {
            "roles": ["admin"],
            "projects": ["foo", "bar"],
        },
    )
    await routing_table.register_model("model-1", provider_id="test_provider")
    await routing_table.register_model("model-2", provider_id="test_provider")
    await routing_table.register_model("model-3", provider_id="test_provider")
    model = await routing_table.get_model("model-1")
    assert model.identifier == "model-1"
    model = await routing_table.get_model("model-2")
    assert model.identifier == "model-2"
    model = await routing_table.get_model("model-3")
    assert model.identifier == "model-3"

    mock_get_authenticated_user.return_value = User(
        "user-2",
        {
            "roles": ["user"],
            "projects": ["foo"],
        },
    )
    model = await routing_table.get_model("model-1")
    assert model.identifier == "model-1"
    with pytest.raises(ValueError):
        await routing_table.get_model("model-2")
    with pytest.raises(ValueError):
        await routing_table.get_model("model-3")
    with pytest.raises(AccessDeniedError):
        await routing_table.register_model("model-4", provider_id="test_provider")
    with pytest.raises(AccessDeniedError):
        await routing_table.unregister_model("model-1")

    mock_get_authenticated_user.return_value = User(
        "user-3",
        {
            "roles": ["user"],
            "projects": ["bar"],
        },
    )
    model = await routing_table.get_model("model-2")
    assert model.identifier == "model-2"
    with pytest.raises(ValueError):
        await routing_table.get_model("model-1")
    with pytest.raises(ValueError):
        await routing_table.get_model("model-3")
    with pytest.raises(AccessDeniedError):
        await routing_table.register_model("model-5", provider_id="test_provider")
    with pytest.raises(AccessDeniedError):
        await routing_table.unregister_model("model-2")

    mock_get_authenticated_user.return_value = User(
        "user-1",
        {
            "roles": ["admin"],
            "projects": ["foo", "bar"],
        },
    )
    await routing_table.unregister_model("model-3")
    with pytest.raises(ValueError):
        await routing_table.get_model("model-3")

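The test above drives the policy through a ModelsRoutingTable; the condition tests below exercise the same rules directly with is_action_allowed. A minimal standalone sketch of that direct path follows; the llama_stack import paths are assumptions inferred from module names visible in this diff (only the conditions module path appears verbatim below), and the ordering comment reflects behaviour the fixture test is consistent with, not a documented guarantee.

# Minimal sketch, not part of the patch. Import paths are assumed.
import yaml
from pydantic import TypeAdapter

from llama_stack.apis.models import ModelType  # assumed path
from llama_stack.distribution.access_control.access_control import is_action_allowed  # assumed path
from llama_stack.distribution.access_control.datatypes import AccessRule  # assumed path
from llama_stack.distribution.datatypes import ModelWithOwner, User  # assumed path

config = """
- permit:
    principal: user-2
    actions: [read]
    resource: model::model-1
- forbid:
    actions: [create, read, delete]
"""
# The fixture behaviour above is consistent with rules being evaluated in order,
# so the trailing bare forbid acts as a default deny.
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))

model = ModelWithOwner(identifier="model-1", provider_id="test_provider", model_type=ModelType.llm)
assert is_action_allowed(policy, "read", model, User("user-2", None))
assert not is_action_allowed(policy, "delete", model, User("user-2", None))
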
def test_permit_when():
    config = """
    - permit:
        principal: user-1
        actions: [read]
      when: user in owners namespaces
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("testuser", {"namespaces": ["foo"]}),
    )
    assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
    assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))


def test_permit_unless():
    config = """
    - permit:
        principal: user-1
        actions: [read]
        resource: model::*
      unless:
        - user not in owners namespaces
        - user in owners teams
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("testuser", {"namespaces": ["foo"]}),
    )
    assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
    assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))


def test_forbid_when():
    config = """
    - forbid:
        principal: user-1
        actions: [read]
      when:
        user in owners namespaces
    - permit:
        actions: [read]
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("testuser", {"namespaces": ["foo"]}),
    )
    assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
    assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))


def test_forbid_unless():
    config = """
    - forbid:
        principal: user-1
        actions: [read]
      unless:
        user in owners namespaces
    - permit:
        actions: [read]
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("testuser", {"namespaces": ["foo"]}),
    )
    assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))


def test_user_has_attribute():
    config = """
    - permit:
        actions: [read]
      when: user with admin in roles
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
    )
    assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
    assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-4", None))


def test_user_does_not_have_attribute():
    config = """
    - permit:
        actions: [read]
      unless: user with admin not in roles
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
    )
    assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
    assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-4", None))


def test_is_owner():
    config = """
    - permit:
        actions: [read]
      when: user is owner
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("user-2", {"namespaces": ["foo"]}),
    )
    assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
    assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-4", None))


def test_is_not_owner():
    config = """
    - permit:
        actions: [read]
      unless: user is not owner
    """
    policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
    model = ModelWithOwner(
        identifier="mymodel",
        provider_id="myprovider",
        model_type=ModelType.llm,
        owner=User("user-2", {"namespaces": ["foo"]}),
    )
    assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
    assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
    assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
    assert not is_action_allowed(policy, "read", model, User("user-4", None))


def test_invalid_rule_permit_and_forbid_both_specified():
    config = """
    - permit:
        actions: [read]
      forbid:
        actions: [create]
    """
    with pytest.raises(ValidationError):
        TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))


def test_invalid_rule_neither_permit_or_forbid_specified():
    config = """
    - when: user is owner
      unless: user with admin in roles
    """
    with pytest.raises(ValidationError):
        TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))


def test_invalid_rule_when_and_unless_both_specified():
    config = """
    - permit:
        actions: [read]
      when: user is owner
      unless: user with admin in roles
    """
    with pytest.raises(ValidationError):
        TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))


def test_invalid_condition():
    config = """
    - permit:
        actions: [read]
      when: random words that are not valid
    """
    with pytest.raises(ValidationError):
        TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))


@pytest.mark.parametrize(
    "condition",
    [
        "user is owner",
        "user is not owner",
        "user with dev in teams",
        "user with default not in namespaces",
        "user in owners roles",
        "user not in owners projects",
    ],
)
def test_condition_reprs(condition):
    from llama_stack.distribution.access_control.conditions import parse_condition

    assert condition == str(parse_condition(condition))

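Taken together, the parametrized reprs above enumerate the full condition grammar these tests cover: ownership checks ("user is [not] owner"), attribute membership on the caller ("user with X [not] in Y"), and comparison against the owner's attributes ("user [not] in owners Y"). An illustrative composite policy, not taken from the patch, using only condition strings that appear verbatim in this file:

# Illustrative only; every condition string below appears verbatim in the tests above.
config = """
- permit:
    actions: [read]
  when: user is owner
- permit:
    actions: [read]
  when: user with admin in roles
- forbid:
    actions: [delete]
  unless: user in owners namespaces
"""
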

@@ -139,7 +139,7 @@ async def mock_post_success(*args, **kwargs):

        {
            "message": "Authentication successful",
            "principal": "test-principal",
            "access_attributes": {
            "attributes": {
                "roles": ["admin", "user"],
                "teams": ["ml-team", "nlp-team"],
                "projects": ["llama-3", "project-x"],

@@ -233,7 +233,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock_scope):

        {
            "message": "Authentication successful",
            "principal": "test-principal",
            "access_attributes": {
            "attributes": {
                "roles": ["admin", "user"],
                "teams": ["ml-team", "nlp-team"],
                "projects": ["llama-3", "project-x"],

@@ -255,33 +255,6 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock_scope):

    mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)


@pytest.mark.asyncio
async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
    """Test middleware behavior with no access attributes"""
    middleware, mock_app = mock_http_middleware
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    with patch("httpx.AsyncClient") as mock_client:
        mock_client_instance = AsyncMock()
        mock_client.return_value.__aenter__.return_value = mock_client_instance

        mock_client_instance.post.return_value = MockResponse(
            200,
            {
                "message": "Authentication successful"
                # No access_attributes
            },
        )

        await middleware(mock_scope, mock_receive, mock_send)

        assert "user_attributes" in mock_scope
        attributes = mock_scope["user_attributes"]
        assert "roles" in attributes
        assert attributes["roles"] == ["test.jwt.token"]


# oauth2 token provider tests

@@ -380,16 +353,16 @@ def test_get_attributes_from_claims():

"aud": "llama-stack",
|
||||
}
|
||||
attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"})
|
||||
assert attributes.roles == ["my-user"]
|
||||
assert attributes.teams == ["group1", "group2"]
|
||||
assert attributes["roles"] == ["my-user"]
|
||||
assert attributes["teams"] == ["group1", "group2"]
|
||||
|
||||
claims = {
|
||||
"sub": "my-user",
|
||||
"tenant": "my-tenant",
|
||||
}
|
||||
attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"})
|
||||
assert attributes.roles == ["my-user"]
|
||||
assert attributes.namespaces == ["my-tenant"]
|
||||
assert attributes["roles"] == ["my-user"]
|
||||
assert attributes["namespaces"] == ["my-tenant"]
|
||||
|
||||
claims = {
|
||||
"sub": "my-user",
|
||||
|
|

@@ -408,9 +381,9 @@ def test_get_attributes_from_claims():

"groups": "teams",
|
||||
},
|
||||
)
|
||||
assert set(attributes.roles) == {"my-user", "my-username"}
|
||||
assert set(attributes.teams) == {"my-team", "group1", "group2"}
|
||||
assert attributes.namespaces == ["my-tenant"]
|
||||
assert set(attributes["roles"]) == {"my-user", "my-username"}
|
||||
assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
|
||||
assert attributes["namespaces"] == ["my-tenant"]
|
||||
|
||||
|
||||
# TODO: add more tests for oauth2 token provider
|
||||
|
|
|
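The assertions in this hunk pin down the mapping semantics of get_attributes_from_claims: each claim named in the mapping feeds its value(s) into the listed attribute, string claims become single-element lists, and attributes fed by several claims are merged. A hypothetical re-implementation, written here only to document that observed behaviour, not the shipped code:

# Hypothetical sketch inferred from the assertions above; not the real implementation.
def get_attributes_from_claims_sketch(claims: dict, mapping: dict[str, str]) -> dict[str, list[str]]:
    attributes: dict[str, list[str]] = {}
    for claim_key, attribute_name in mapping.items():
        if claim_key not in claims:
            continue  # absent claims contribute nothing
        value = claims[claim_key]
        values = [value] if isinstance(value, str) else list(value)
        attributes.setdefault(attribute_name, []).extend(values)  # merge across claims
    return attributes
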

@@ -100,9 +100,10 @@ async def test_resolve_impls_basic():

    add_protocol_methods(SampleImpl, Inference)

    mock_module.get_provider_impl = AsyncMock(return_value=impl)
    mock_module.get_provider_impl.__text_signature__ = "()"
    sys.modules["test_module"] = mock_module

    impls = await resolve_impls(run_config, provider_registry, dist_registry)
    impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={})

    assert Api.inference in impls
    assert isinstance(impls[Api.inference], InferenceRouter)