chore: default to pytest asyncio-mode=auto (#2730)

# What does this PR do?

previously, developers who ran `./scripts/unit-tests.sh` would get
`asyncio-mode=auto`, which made `@pytest.mark.asyncio` and
`@pytest_asyncio.fixture` redundant. developers who ran `pytest`
directly would get pytest's default (strict mode) and run into errors
that led them to add `@pytest.mark.asyncio` / `@pytest_asyncio.fixture`
back into their code.
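
as an illustration, here's a minimal (hypothetical) test module - under
`asyncio_mode = "auto"`, pytest collects and runs it as-is, while strict
mode refuses to run the unmarked coroutine test:

```python
import asyncio

import pytest


@pytest.fixture
async def greeting():
    # a plain async fixture; auto mode treats it as an asyncio fixture,
    # so @pytest_asyncio.fixture is unnecessary
    await asyncio.sleep(0)
    return "hello"


async def test_greeting(greeting):
    # no @pytest.mark.asyncio needed in auto mode
    assert greeting == "hello"
```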

with this change -
- `asyncio_mode = "auto"` is set in `pyproject.toml`, making behavior
consistent across all invocations of pytest
- all redundant `@pytest.mark.asyncio` / `@pytest_asyncio.fixture`
usages are removed
- for good measure, `pytest>=8.4` and `pytest-asyncio>=1.0` are now
required

## Test Plan

- `./scripts/unit-tests.sh`
- `uv run pytest tests/unit`
Matthew Farrellee 2025-07-11 16:00:24 -04:00 committed by GitHub
parent 2ebc172f33
commit 30b2e6a495
35 changed files with 29 additions and 239 deletions

View file

@@ -58,9 +58,9 @@ ui = [

 [dependency-groups]
 dev = [
-    "pytest",
+    "pytest>=8.4",
     "pytest-timeout",
-    "pytest-asyncio",
+    "pytest-asyncio>=1.0",
     "pytest-cov",
     "pytest-html",
     "pytest-json-report",
@@ -339,3 +339,6 @@ warn_required_dynamic_aliases = true

 [tool.ruff.lint.pep8-naming]
 classmethod-decorators = ["classmethod", "pydantic.field_validator"]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"

View file

@@ -16,4 +16,4 @@ if [ $FOUND_PYTHON -ne 0 ]; then
     uv python install "$PYTHON_VERSION"
 fi

-uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest --asyncio-mode=auto -s -v tests/unit/ $@
+uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest -s -v tests/unit/ $@

View file

@@ -44,7 +44,6 @@ def common_params(inference_model):
     )


-    @pytest.mark.asyncio
     @pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world")
     async def test_delete_agents_and_sessions(self, agents_stack, common_params):
         agents_impl = agents_stack.impls[Api.agents]
@@ -73,7 +72,6 @@ async def test_delete_agents_and_sessions(self, agents_stack, common_params):
         assert agent_response is None

-    @pytest.mark.asyncio
     @pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world")
     async def test_get_agent_turns_and_steps(self, agents_stack, sample_messages, common_params):
         agents_impl = agents_stack.impls[Api.agents]

View file

@@ -4,20 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import pytest
 from llama_stack_client import LlamaStackClient

 from llama_stack import LlamaStackAsLibraryClient


 class TestInspect:
-    @pytest.mark.asyncio
     def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
         health = llama_stack_client.inspect.health()
         assert health is not None
         assert health.status == "OK"

-    @pytest.mark.asyncio
     def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
         version = llama_stack_client.inspect.version()
         assert version is not None

View file

@@ -4,14 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import pytest
 from llama_stack_client import LlamaStackClient

 from llama_stack import LlamaStackAsLibraryClient


 class TestProviders:
-    @pytest.mark.asyncio
     def test_providers(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
         provider_list = llama_stack_client.providers.list()
         assert provider_list is not None

View file

@@ -88,7 +88,6 @@ async def cleanup_records(sql_store, table_name, record_ids):
             pass


-@pytest.mark.asyncio
 @pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
 @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
 async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request):
@@ -183,7 +182,6 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request):
     await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"])


-@pytest.mark.asyncio
 @pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
 @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
 async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request):

View file

@@ -8,8 +8,6 @@
 from unittest.mock import AsyncMock

-import pytest
-
 from llama_stack.apis.common.type_system import NumberType
 from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
 from llama_stack.apis.datatypes import Api
@@ -119,7 +117,6 @@ class ToolGroupsImpl(Impl):
     )


-@pytest.mark.asyncio
 async def test_models_routing_table(cached_disk_dist_registry):
     table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
     await table.initialize()
@@ -161,7 +158,6 @@ async def test_models_routing_table(cached_disk_dist_registry):
     assert len(openai_models.data) == 0


-@pytest.mark.asyncio
 async def test_shields_routing_table(cached_disk_dist_registry):
     table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
     await table.initialize()
@@ -177,7 +173,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
     assert "test-shield-2" in shield_ids


-@pytest.mark.asyncio
 async def test_vectordbs_routing_table(cached_disk_dist_registry):
     table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
     await table.initialize()
@@ -233,7 +228,6 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
     assert len(datasets.data) == 0


-@pytest.mark.asyncio
 async def test_scoring_functions_routing_table(cached_disk_dist_registry):
     table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
     await table.initialize()
@@ -259,7 +253,6 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
     assert "test-scoring-fn-2" in scoring_fn_ids


-@pytest.mark.asyncio
 async def test_benchmarks_routing_table(cached_disk_dist_registry):
     table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
     await table.initialize()
@@ -277,7 +270,6 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
     assert "test-benchmark" in benchmark_ids


-@pytest.mark.asyncio
 async def test_tool_groups_routing_table(cached_disk_dist_registry):
     table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
     await table.initialize()

View file

@@ -13,7 +13,6 @@ import pytest
 from llama_stack.distribution.utils.context import preserve_contexts_async_generator


-@pytest.mark.asyncio
 async def test_preserve_contexts_with_exception():
     # Create context variable
     context_var = ContextVar("exception_var", default="initial")
@@ -41,7 +40,6 @@ async def test_preserve_contexts_with_exception():
         context_var.reset(token)


-@pytest.mark.asyncio
 async def test_preserve_contexts_empty_generator():
     # Create context variable
     context_var = ContextVar("empty_var", default="initial")
@@ -66,7 +64,6 @@ async def test_preserve_contexts_empty_generator():
         context_var.reset(token)


-@pytest.mark.asyncio
 async def test_preserve_contexts_across_event_loops():
     """
     Test that context variables are preserved across event loop boundaries with nested generators.

View file

@@ -6,7 +6,6 @@

 import pytest
-import pytest_asyncio

 from llama_stack.apis.common.responses import Order
 from llama_stack.apis.files import OpenAIFilePurpose
@@ -29,7 +28,7 @@ class MockUploadFile:
         return self.content


-@pytest_asyncio.fixture
+@pytest.fixture
 async def files_provider(tmp_path):
     """Create a files provider with temporary storage for testing."""
     storage_dir = tmp_path / "files"
@@ -68,7 +67,6 @@ def large_file():
 class TestOpenAIFilesAPI:
     """Test suite for OpenAI Files API endpoints."""

-    @pytest.mark.asyncio
     async def test_upload_file_success(self, files_provider, sample_text_file):
         """Test successful file upload."""
         # Upload file
@@ -82,7 +80,6 @@ class TestOpenAIFilesAPI:
         assert result.created_at > 0
         assert result.expires_at > result.created_at

-    @pytest.mark.asyncio
     async def test_upload_different_purposes(self, files_provider, sample_text_file):
         """Test uploading files with different purposes."""
         purposes = list(OpenAIFilePurpose)
@@ -93,7 +90,6 @@ class TestOpenAIFilesAPI:
             uploaded_files.append(result)
             assert result.purpose == purpose

-    @pytest.mark.asyncio
     async def test_upload_different_file_types(self, files_provider, sample_text_file, sample_json_file, large_file):
         """Test uploading different types and sizes of files."""
         files_to_test = [
@@ -107,7 +103,6 @@ class TestOpenAIFilesAPI:
             assert result.filename == expected_filename
             assert result.bytes == len(file_obj.content)

-    @pytest.mark.asyncio
     async def test_list_files_empty(self, files_provider):
         """Test listing files when no files exist."""
         result = await files_provider.openai_list_files()
@@ -117,7 +112,6 @@ class TestOpenAIFilesAPI:
         assert result.first_id == ""
         assert result.last_id == ""

-    @pytest.mark.asyncio
     async def test_list_files_with_content(self, files_provider, sample_text_file, sample_json_file):
         """Test listing files when files exist."""
         # Upload multiple files
@@ -132,7 +126,6 @@ class TestOpenAIFilesAPI:
         assert file1.id in file_ids
         assert file2.id in file_ids

-    @pytest.mark.asyncio
     async def test_list_files_with_purpose_filter(self, files_provider, sample_text_file):
         """Test listing files with purpose filtering."""
         # Upload file with specific purpose
@@ -146,7 +139,6 @@ class TestOpenAIFilesAPI:
         assert result.data[0].id == uploaded_file.id
         assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS

-    @pytest.mark.asyncio
     async def test_list_files_with_limit(self, files_provider, sample_text_file):
         """Test listing files with limit parameter."""
         # Upload multiple files
@@ -157,7 +149,6 @@ class TestOpenAIFilesAPI:
         result = await files_provider.openai_list_files(limit=3)
         assert len(result.data) == 3

-    @pytest.mark.asyncio
     async def test_list_files_with_order(self, files_provider, sample_text_file):
         """Test listing files with different order."""
         # Upload multiple files
@@ -178,7 +169,6 @@ class TestOpenAIFilesAPI:
         # Oldest should be first
         assert result_asc.data[0].created_at <= result_asc.data[1].created_at <= result_asc.data[2].created_at

-    @pytest.mark.asyncio
     async def test_retrieve_file_success(self, files_provider, sample_text_file):
         """Test successful file retrieval."""
         # Upload file
@@ -197,13 +187,11 @@ class TestOpenAIFilesAPI:
         assert retrieved_file.created_at == uploaded_file.created_at
         assert retrieved_file.expires_at == uploaded_file.expires_at

-    @pytest.mark.asyncio
     async def test_retrieve_file_not_found(self, files_provider):
         """Test retrieving a non-existent file."""
         with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
             await files_provider.openai_retrieve_file("file-nonexistent")

-    @pytest.mark.asyncio
     async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
         """Test successful file content retrieval."""
         # Upload file
@@ -217,13 +205,11 @@ class TestOpenAIFilesAPI:
         # Verify content
         assert content.body == sample_text_file.content

-    @pytest.mark.asyncio
     async def test_retrieve_file_content_not_found(self, files_provider):
         """Test retrieving content of a non-existent file."""
         with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
             await files_provider.openai_retrieve_file_content("file-nonexistent")

-    @pytest.mark.asyncio
     async def test_delete_file_success(self, files_provider, sample_text_file):
         """Test successful file deletion."""
         # Upload file
@@ -245,13 +231,11 @@ class TestOpenAIFilesAPI:
         with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
             await files_provider.openai_retrieve_file(uploaded_file.id)

-    @pytest.mark.asyncio
     async def test_delete_file_not_found(self, files_provider):
         """Test deleting a non-existent file."""
         with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
             await files_provider.openai_delete_file("file-nonexistent")

-    @pytest.mark.asyncio
     async def test_file_persistence_across_operations(self, files_provider, sample_text_file):
         """Test that files persist correctly across multiple operations."""
         # Upload file
@@ -279,7 +263,6 @@ class TestOpenAIFilesAPI:
         files_list = await files_provider.openai_list_files()
         assert len(files_list.data) == 0

-    @pytest.mark.asyncio
     async def test_multiple_files_operations(self, files_provider, sample_text_file, sample_json_file):
         """Test operations with multiple files."""
         # Upload multiple files
@@ -302,7 +285,6 @@ class TestOpenAIFilesAPI:
         content = await files_provider.openai_retrieve_file_content(file2.id)
         assert content.body == sample_json_file.content

-    @pytest.mark.asyncio
     async def test_file_id_uniqueness(self, files_provider, sample_text_file):
         """Test that each uploaded file gets a unique ID."""
         file_ids = set()
@@ -316,7 +298,6 @@ class TestOpenAIFilesAPI:
             file_ids.add(uploaded_file.id)
             assert uploaded_file.id.startswith("file-")

-    @pytest.mark.asyncio
     async def test_file_no_filename_handling(self, files_provider):
         """Test handling files with no filename."""
         file_without_name = MockUploadFile(b"content", None)  # No filename
@@ -327,7 +308,6 @@ class TestOpenAIFilesAPI:
         assert uploaded_file.filename == "uploaded_file"  # Default filename

-    @pytest.mark.asyncio
     async def test_after_pagination_works(self, files_provider, sample_text_file):
         """Test that 'after' pagination works correctly."""
         # Upload multiple files to test pagination

View file

@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import pytest_asyncio
+import pytest

 from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl


-@pytest_asyncio.fixture(scope="function")
+@pytest.fixture(scope="function")
 async def sqlite_kvstore(tmp_path):
     db_path = tmp_path / "test_kv.db"
     kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
@@ -20,14 +20,14 @@ async def sqlite_kvstore(tmp_path):
     yield kvstore


-@pytest_asyncio.fixture(scope="function")
+@pytest.fixture(scope="function")
 async def disk_dist_registry(sqlite_kvstore):
     registry = DiskDistributionRegistry(sqlite_kvstore)
     await registry.initialize()

     yield registry


-@pytest_asyncio.fixture(scope="function")
+@pytest.fixture(scope="function")
 async def cached_disk_dist_registry(sqlite_kvstore):
     registry = CachedDiskDistributionRegistry(sqlite_kvstore)
     await registry.initialize()

View file

@@ -8,7 +8,6 @@ from datetime import datetime
 from unittest.mock import AsyncMock

 import pytest
-import pytest_asyncio

 from llama_stack.apis.agents import (
     Agent,
@@ -50,7 +49,7 @@ def config(tmp_path):
     )


-@pytest_asyncio.fixture
+@pytest.fixture
 async def agents_impl(config, mock_apis):
     impl = MetaReferenceAgentsImpl(
         config,
@@ -117,7 +116,6 @@ def sample_agent_config():
     )


-@pytest.mark.asyncio
 async def test_create_agent(agents_impl, sample_agent_config):
     response = await agents_impl.create_agent(sample_agent_config)
@@ -132,7 +130,6 @@ async def test_create_agent(agents_impl, sample_agent_config):
     assert isinstance(agent_info.created_at, datetime)


-@pytest.mark.asyncio
 async def test_get_agent(agents_impl, sample_agent_config):
     create_response = await agents_impl.create_agent(sample_agent_config)
     agent_id = create_response.agent_id
@@ -146,7 +143,6 @@ async def test_get_agent(agents_impl, sample_agent_config):
     assert isinstance(agent.created_at, datetime)


-@pytest.mark.asyncio
 async def test_list_agents(agents_impl, sample_agent_config):
     agent1_response = await agents_impl.create_agent(sample_agent_config)
     agent2_response = await agents_impl.create_agent(sample_agent_config)
@@ -160,7 +156,6 @@ async def test_list_agents(agents_impl, sample_agent_config):
     assert agent2_response.agent_id in agent_ids


-@pytest.mark.asyncio
 @pytest.mark.parametrize("enable_session_persistence", [True, False])
 async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
     # Create an agent with specified persistence setting
@@ -188,7 +183,6 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
     await agents_impl.get_agents_session(agent_id, session_response.session_id)


-@pytest.mark.asyncio
 @pytest.mark.parametrize("enable_session_persistence", [True, False])
 async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
     # Create an agent with specified persistence setting
@@ -221,7 +215,6 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
     assert session2.session_id in {s["session_id"] for s in sessions.data}


-@pytest.mark.asyncio
 async def test_delete_agent(agents_impl, sample_agent_config):
     # Create an agent
     response = await agents_impl.create_agent(sample_agent_config)

View file

@@ -122,7 +122,6 @@ async def fake_stream(fixture: str = "simple_chat_completion.yaml"):
     )


-@pytest.mark.asyncio
 async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
     """Test creating an OpenAI response with a simple string input."""
     # Setup
@@ -155,7 +154,6 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
     assert result.output[0].content[0].text == "Dublin"


-@pytest.mark.asyncio
 async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
     """Test creating an OpenAI response with a simple string input and tools."""
     # Setup
@@ -224,7 +222,6 @@ async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
     assert result.output[1].content[0].annotations == []


-@pytest.mark.asyncio
 async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api):
     """Test creating an OpenAI response with a tool call response that has a type of None."""
     # Setup
@@ -294,7 +291,6 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api):
     assert chunks[1].response.output[0].name == "get_weather"


-@pytest.mark.asyncio
 async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
     """Test creating an OpenAI response with multiple messages."""
     # Setup
@@ -340,7 +336,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
         assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)


-@pytest.mark.asyncio
 async def test_prepend_previous_response_none(openai_responses_impl):
     """Test prepending no previous response to a new response."""
@@ -348,7 +343,6 @@ async def test_prepend_previous_response_none(openai_responses_impl):
     assert input == "fake_input"


-@pytest.mark.asyncio
 async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
     """Test prepending a basic previous response to a new response."""
@@ -388,7 +382,6 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
     assert input[2].content == "fake_input"


-@pytest.mark.asyncio
 async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store):
     """Test prepending a web search previous response to a new response."""
     input_item_message = OpenAIResponseMessage(
@@ -434,7 +427,6 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store):
     assert input[3].content == "fake_input"


-@pytest.mark.asyncio
 async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
     # Setup
     input_text = "What is the capital of Ireland?"
@@ -463,7 +455,6 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
     assert sent_messages[1].content == input_text


-@pytest.mark.asyncio
 async def test_create_openai_response_with_instructions_and_multiple_messages(
     openai_responses_impl, mock_inference_api
 ):
@@ -508,7 +499,6 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
     assert sent_messages[3].content == "Which is the largest?"


-@pytest.mark.asyncio
 async def test_create_openai_response_with_instructions_and_previous_response(
     openai_responses_impl, mock_responses_store, mock_inference_api
 ):
@@ -565,7 +555,6 @@ async def test_create_openai_response_with_instructions_and_previous_response(
     assert sent_messages[3].content == "Which is the largest?"


-@pytest.mark.asyncio
 async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
     """Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
     # Setup
@@ -601,7 +590,6 @@ async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
     assert result.data[0].id == "msg_123"


-@pytest.mark.asyncio
 async def test_responses_store_list_input_items_logic():
     """Test ResponsesStore list_response_input_items logic - mocks get_response_object to test actual ordering/limiting."""
@@ -680,7 +668,6 @@ async def test_responses_store_list_input_items_logic():
     assert len(result.data) == 0  # Should return no items


-@pytest.mark.asyncio
 async def test_store_response_uses_rehydrated_input_with_previous_response(
     openai_responses_impl, mock_responses_store, mock_inference_api
 ):
@@ -747,7 +734,6 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
     assert result.status == "completed"


-@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "text_format, response_format",
     [
@@ -787,7 +773,6 @@ async def test_create_openai_response_with_text_format(
     assert first_call.kwargs["response_format"] == response_format


-@pytest.mark.asyncio
 async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
     """Test creating an OpenAI response with an invalid text format."""
     # Setup

View file

@@ -9,7 +9,6 @@ from datetime import datetime
 from unittest.mock import patch

 import pytest
-import pytest_asyncio

 from llama_stack.apis.agents import Turn
 from llama_stack.apis.inference import CompletionMessage, StopReason
@@ -17,13 +16,12 @@ from llama_stack.distribution.datatypes import User
 from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo


-@pytest_asyncio.fixture
+@pytest.fixture
 async def test_setup(sqlite_kvstore):
     agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
     yield agent_persistence


-@pytest.mark.asyncio
 @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
 async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
     agent_persistence = test_setup
@@ -44,7 +42,6 @@ async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
     assert session_info.owner.attributes["teams"] == ["ai-team"]


-@pytest.mark.asyncio
 @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
 async def test_session_access_control(mock_get_authenticated_user, test_setup):
     agent_persistence = test_setup
@@ -79,7 +76,6 @@ async def test_session_access_control(mock_get_authenticated_user, test_setup):
     assert retrieved_session is None


-@pytest.mark.asyncio
 @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
 async def test_turn_access_control(mock_get_authenticated_user, test_setup):
     agent_persistence = test_setup
@@ -133,7 +129,6 @@ async def test_turn_access_control(mock_get_authenticated_user, test_setup):
         await agent_persistence.get_session_turns(session_id)


-@pytest.mark.asyncio
 @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
 async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
     agent_persistence = test_setup

View file

@@ -14,7 +14,6 @@ from typing import Any
 from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
-import pytest_asyncio
 from openai.types.chat.chat_completion_chunk import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
 )
@@ -103,7 +102,7 @@ def mock_openai_models_list():
         yield mock_list


-@pytest_asyncio.fixture(scope="module")
+@pytest.fixture(scope="module")
 async def vllm_inference_adapter():
     config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
     inference_adapter = VLLMInferenceAdapter(config)
@@ -112,7 +111,6 @@ async def vllm_inference_adapter():
     return inference_adapter


-@pytest.mark.asyncio
 async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
     async def mock_openai_models():
         yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
@@ -125,7 +123,6 @@ async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
     mock_openai_models_list.assert_called()


-@pytest.mark.asyncio
 async def test_old_vllm_tool_choice(vllm_inference_adapter):
     """
     Test that we set tool_choice to none when no tools are in use
@@ -149,7 +146,6 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
     assert request.tool_config.tool_choice == ToolChoice.none


-@pytest.mark.asyncio
 async def test_tool_call_response(vllm_inference_adapter):
     """Verify that tool call arguments from a CompletionMessage are correctly converted
     into the expected JSON format."""
@@ -192,7 +188,6 @@ async def test_tool_call_response(vllm_inference_adapter):
     ]


-@pytest.mark.asyncio
 async def test_tool_call_delta_empty_tool_call_buf():
     """
     Test that we don't generate extra chunks when processing a
@@ -222,7 +217,6 @@ async def test_tool_call_delta_empty_tool_call_buf():
     assert chunks[1].event.stop_reason == StopReason.end_of_turn


-@pytest.mark.asyncio
 async def test_tool_call_delta_streaming_arguments_dict():
     async def mock_stream():
         mock_chunk_1 = OpenAIChatCompletionChunk(
@@ -297,7 +291,6 @@ async def test_tool_call_delta_streaming_arguments_dict():
     assert chunks[2].event.event_type.value == "complete"


-@pytest.mark.asyncio
 async def test_multiple_tool_calls():
     async def mock_stream():
         mock_chunk_1 = OpenAIChatCompletionChunk(
@@ -376,7 +369,6 @@ async def test_multiple_tool_calls():
     assert chunks[3].event.event_type.value == "complete"


-@pytest.mark.asyncio
 async def test_process_vllm_chat_completion_stream_response_no_choices():
     """
     Test that we don't error out when vLLM returns no choices for a
@@ -453,7 +445,6 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
     assert not asyncio_warnings


-@pytest.mark.asyncio
 async def test_get_params_empty_tools(vllm_inference_adapter):
     request = ChatCompletionRequest(
         tools=[],
@@ -464,7 +455,6 @@ async def test_get_params_empty_tools(vllm_inference_adapter):
     assert "tools" not in params


-@pytest.mark.asyncio
 async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
     """
     Tests the edge case where the model returns the arguments for the tool call in the same chunk that
@@ -543,7 +533,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
     assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments


-@pytest.mark.asyncio
 async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
     """
     Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
@@ -596,7 +585,6 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
     assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments


-@pytest.mark.asyncio
 async def test_process_vllm_chat_completion_stream_response_tool_without_args():
     """
     Tests the edge case where no arguments are provided for the tool call.
@@ -645,7 +633,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
     assert chunks[-2].event.delta.tool_call.arguments == {}


-@pytest.mark.asyncio
 async def test_health_status_success(vllm_inference_adapter):
     """
     Test the health method of VLLM InferenceAdapter when the connection is successful.
@@ -679,7 +666,6 @@ async def test_health_status_success(vllm_inference_adapter):
     mock_models.list.assert_called_once()


-@pytest.mark.asyncio
 async def test_health_status_failure(vllm_inference_adapter):
     """
     Test the health method of VLLM InferenceAdapter when the connection fails.

View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import pytest

 from llama_stack.apis.common.content_types import TextContentItem
 from llama_stack.apis.inference import (
@@ -23,7 +22,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
 )


-@pytest.mark.asyncio
 async def test_convert_message_to_openai_dict():
     message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
     assert await convert_message_to_openai_dict(message) == {
@@ -33,7 +31,6 @@ async def test_convert_message_to_openai_dict():


 # Test convert_message_to_openai_dict with a tool call
-@pytest.mark.asyncio
 async def test_convert_message_to_openai_dict_with_tool_call():
     message = CompletionMessage(
         content="",
@@ -54,7 +51,6 @@ async def test_convert_message_to_openai_dict_with_tool_call():
     }


-@pytest.mark.asyncio
 async def test_convert_message_to_openai_dict_with_builtin_tool_call():
     message = CompletionMessage(
         content="",
@@ -80,7 +76,6 @@ async def test_convert_message_to_openai_dict_with_builtin_tool_call():
     }


-@pytest.mark.asyncio
 async def test_openai_messages_to_messages_with_content_str():
     openai_messages = [
         OpenAISystemMessageParam(content="system message"),
@@ -98,7 +93,6 @@ async def test_openai_messages_to_messages_with_content_str():
     assert llama_messages[2].content == "assistant message"


-@pytest.mark.asyncio
 async def test_openai_messages_to_messages_with_content_list():
     openai_messages = [
         OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),

View file

@@ -13,7 +13,6 @@ from llama_stack.apis.tools import RAGDocument
 from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc


-@pytest.mark.asyncio
 async def test_content_from_doc_with_url():
     """Test extracting content from RAGDocument with URL content."""
     mock_url = URL(uri="https://example.com")
@@ -33,7 +32,6 @@ async def test_content_from_doc_with_url():
         mock_instance.get.assert_called_once_with(mock_url.uri)


-@pytest.mark.asyncio
 async def test_content_from_doc_with_pdf_url():
     """Test extracting content from RAGDocument with URL pointing to a PDF."""
     mock_url = URL(uri="https://example.com/document.pdf")
@@ -58,7 +56,6 @@ async def test_content_from_doc_with_pdf_url():
         mock_parse_pdf.assert_called_once_with(b"PDF binary data")


-@pytest.mark.asyncio
 async def test_content_from_doc_with_data_url():
     """Test extracting content from RAGDocument with data URL content."""
     data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ="  # "Hello World" base64 encoded
@@ -74,7 +71,6 @@ async def test_content_from_doc_with_data_url():
         mock_content_from_data.assert_called_once_with(data_url)


-@pytest.mark.asyncio
 async def test_content_from_doc_with_string():
     """Test extracting content from RAGDocument with string content."""
     content_string = "This is plain text content"
@@ -85,7 +81,6 @@ async def test_content_from_doc_with_string():
     assert result == content_string


-@pytest.mark.asyncio
 async def test_content_from_doc_with_string_url():
     """Test extracting content from RAGDocument with string URL content."""
     url_string = "https://example.com"
@@ -105,7 +100,6 @@ async def test_content_from_doc_with_string_url():
         mock_instance.get.assert_called_once_with(url_string)


-@pytest.mark.asyncio
 async def test_content_from_doc_with_string_pdf_url():
     """Test extracting content from RAGDocument with string URL pointing to a PDF."""
     url_string = "https://example.com/document.pdf"
@@ -130,7 +124,6 @@ async def test_content_from_doc_with_string_pdf_url():
         mock_parse_pdf.assert_called_once_with(b"PDF binary data")


-@pytest.mark.asyncio
 async def test_content_from_doc_with_interleaved_content():
     """Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit)."""
     interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")]

View file

@ -87,18 +87,15 @@ def helper(known_provider_model: ProviderModelEntry, known_provider_model2: Prov
return ModelRegistryHelper([known_provider_model, known_provider_model2]) return ModelRegistryHelper([known_provider_model, known_provider_model2])
@pytest.mark.asyncio
async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
assert helper.get_provider_model_id(unknown_model.model_id) is None assert helper.get_provider_model_id(unknown_model.model_id) is None
@pytest.mark.asyncio
async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
await helper.register_model(unknown_model) await helper.register_model(unknown_model)
@pytest.mark.asyncio
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
model = Model( model = Model(
provider_id=known_model.provider_id, provider_id=known_model.provider_id,
@ -110,7 +107,6 @@ async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -
assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
model = Model( model = Model(
provider_id=known_model.provider_id, provider_id=known_model.provider_id,
@ -122,13 +118,11 @@ async def test_register_model_from_alias(helper: ModelRegistryHelper, known_mode
assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
await helper.register_model(known_model) await helper.register_model(known_model)
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_existing_different( async def test_register_model_existing_different(
helper: ModelRegistryHelper, known_model: Model, known_model2: Model helper: ModelRegistryHelper, known_model: Model, known_model2: Model
) -> None: ) -> None:
@ -137,7 +131,6 @@ async def test_register_model_existing_different(
await helper.register_model(known_model) await helper.register_model(known_model)
@pytest.mark.asyncio
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
await helper.register_model(known_model) # duplicate entry await helper.register_model(known_model) # duplicate entry
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
@ -145,18 +138,15 @@ async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model)
assert helper.get_provider_model_id(known_model.model_id) is None assert helper.get_provider_model_id(known_model.model_id) is None
@pytest.mark.asyncio
async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
await helper.unregister_model(unknown_model.model_id) await helper.unregister_model(unknown_model.model_id)
@pytest.mark.asyncio
async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
@pytest.mark.asyncio
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
await helper.unregister_model(known_model.provider_resource_id) await helper.unregister_model(known_model.provider_resource_id)
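
The registry tests above keep wrapping awaited calls in `pytest.raises`, which needs no asyncio-specific support: it is an ordinary synchronous context manager and the `await` happens inside it. A sketch with `_Registry` invented as a stand-in for the real helper:

```python
import pytest


class _Registry:
    """Hypothetical stand-in for the registry helper under test."""

    async def register_model(self, model_id: str) -> None:
        if model_id.startswith("unknown"):
            raise ValueError(f"unknown model: {model_id}")


async def test_register_unknown_raises():
    # The context manager is synchronous; only the call inside is awaited.
    with pytest.raises(ValueError):
        await _Registry().register_model("unknown-model")
```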

View file

@ -11,7 +11,6 @@ import pytest
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
@pytest.mark.asyncio
async def test_scheduler_unknown_backend(): async def test_scheduler_unknown_backend():
with pytest.raises(ValueError): with pytest.raises(ValueError):
Scheduler(backend="unknown") Scheduler(backend="unknown")
@ -26,7 +25,6 @@ async def wait_for_job_completed(sched: Scheduler, job_id: str) -> None:
raise TimeoutError(f"Job {job_id} did not complete in time.") raise TimeoutError(f"Job {job_id} did not complete in time.")
@pytest.mark.asyncio
async def test_scheduler_naive(): async def test_scheduler_naive():
sched = Scheduler() sched = Scheduler()
@ -87,7 +85,6 @@ async def test_scheduler_naive():
assert job.logs[0][0] < job.logs[1][0] assert job.logs[0][0] < job.logs[1][0]
@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises(): async def test_scheduler_naive_handler_raises():
sched = Scheduler() sched = Scheduler()
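
`wait_for_job_completed`, visible in the hunk header above, polls the scheduler and raises `TimeoutError` when it gives up. A generic version of that polling pattern; the interval and attempt counts here are invented, not the real helper's values:

```python
import asyncio

import pytest


async def wait_until(predicate, interval: float = 0.01, attempts: int = 100) -> None:
    # Check a cheap predicate, yielding to the event loop between polls,
    # and fail loudly instead of hanging forever.
    for _ in range(attempts):
        if predicate():
            return
        await asyncio.sleep(interval)
    raise TimeoutError("condition not met in time")


async def test_wait_until_times_out():
    with pytest.raises(TimeoutError):
        await wait_until(lambda: False, interval=0.001, attempts=3)
```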

View file

@ -9,7 +9,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
import pytest_asyncio
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.inference import EmbeddingsResponse, Inference from llama_stack.apis.inference import EmbeddingsResponse, Inference
@ -91,13 +90,13 @@ def faiss_config():
return config return config
@pytest_asyncio.fixture @pytest.fixture
async def faiss_index(embedding_dimension): async def faiss_index(embedding_dimension):
index = await FaissIndex.create(dimension=embedding_dimension) index = await FaissIndex.create(dimension=embedding_dimension)
yield index yield index
@pytest_asyncio.fixture @pytest.fixture
async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter: async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
# Create the adapter # Create the adapter
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api) adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
@ -113,7 +112,6 @@ async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> Fai
yield adapter yield adapter
@pytest.mark.asyncio
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical( async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
faiss_index, sample_chunks, sample_embeddings, embedding_dimension faiss_index, sample_chunks, sample_embeddings, embedding_dimension
): ):
@ -136,7 +134,6 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
assert response.chunks[1] == sample_chunks[1] assert response.chunks[1] == sample_chunks[1]
@pytest.mark.asyncio
async def test_health_success(): async def test_health_success():
"""Test that the health check returns OK status when faiss is working correctly.""" """Test that the health check returns OK status when faiss is working correctly."""
# Create a fresh instance of FaissVectorIOAdapter for testing # Create a fresh instance of FaissVectorIOAdapter for testing
@ -160,7 +157,6 @@ async def test_health_success():
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128 mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
@pytest.mark.asyncio
async def test_health_failure(): async def test_health_failure():
"""Test that the health check returns ERROR status when faiss encounters an error.""" """Test that the health check returns ERROR status when faiss encounters an error."""
# Create a fresh instance of FaissVectorIOAdapter for testing # Create a fresh instance of FaissVectorIOAdapter for testing
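
The fixture hunks above are the second half of the cleanup: in auto mode pytest-asyncio also adopts async generator fixtures declared with plain `@pytest.fixture`, so `@pytest_asyncio.fixture` and its import can be dropped. A sketch, with `_Index` standing in for the real index class:

```python
import pytest


class _Index:
    """Hypothetical stand-in for an async-managed index."""

    async def close(self) -> None:
        pass


@pytest.fixture
async def index():
    idx = _Index()
    yield idx           # value injected into the test
    await idx.close()   # teardown runs after the test on the same loop


async def test_uses_async_fixture(index):
    assert isinstance(index, _Index)
```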

View file

@ -10,7 +10,6 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import pytest_asyncio
from llama_stack.apis.inference import EmbeddingsResponse, Inference from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_io import ( from llama_stack.apis.vector_io import (
@ -68,7 +67,7 @@ def mock_api_service(sample_embeddings):
return mock_api_service return mock_api_service
@pytest_asyncio.fixture @pytest.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter: async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service) adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
adapter.vector_db_store = mock_vector_db_store adapter.vector_db_store = mock_vector_db_store
@ -80,7 +79,6 @@ async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service,
__QUERY = "Sample query" __QUERY = "Sample query"
@pytest.mark.asyncio
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)]) @pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
async def test_qdrant_adapter_returns_expected_chunks( async def test_qdrant_adapter_returns_expected_chunks(
qdrant_adapter: QdrantVectorIOAdapter, qdrant_adapter: QdrantVectorIOAdapter,
@ -111,7 +109,6 @@ def _prepare_for_json(value: Any) -> str:
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json) @patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
@pytest.mark.asyncio
async def test_qdrant_register_and_unregister_vector_db( async def test_qdrant_register_and_unregister_vector_db(
qdrant_adapter: QdrantVectorIOAdapter, qdrant_adapter: QdrantVectorIOAdapter,
mock_vector_db, mock_vector_db,
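
Only the asyncio mark is removed here; other marks such as `@pytest.mark.parametrize` (and decorators like `@patch`) keep stacking on async tests exactly as before. A minimal sketch mirroring the `(2, 2)` / `(100, 60)` parametrization above, with an invented cap of 60 results:

```python
import asyncio

import pytest


@pytest.mark.parametrize("max_chunks, expected", [(2, 2), (100, 60)])
async def test_parametrized(max_chunks: int, expected: int):
    # Each parameter set runs as its own asyncio test in auto mode.
    await asyncio.sleep(0)
    assert min(max_chunks, 60) == expected
```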

View file

@ -8,7 +8,6 @@ import asyncio
import numpy as np import numpy as np
import pytest import pytest
import pytest_asyncio
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import ( from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
@ -34,7 +33,7 @@ def loop():
return asyncio.new_event_loop() return asyncio.new_event_loop()
@pytest_asyncio.fixture @pytest.fixture
async def sqlite_vec_index(embedding_dimension, tmp_path_factory): async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp() temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db") db_path = str(temp_dir / "test_sqlite.db")
@ -43,14 +42,12 @@ async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
await index.delete() await index.delete()
@pytest.mark.asyncio
async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata): async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata):
await sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata) await sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)
response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0) response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0)
assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata
@pytest.mark.asyncio
async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_string = "Sentence 5" query_string = "Sentence 5"
@ -68,7 +65,6 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}" assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
@pytest.mark.asyncio
async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -90,7 +86,6 @@ async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embed
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
# Re-initialize with a clean index # Re-initialize with a clean index
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -103,7 +98,6 @@ async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_i
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found" assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"
@pytest.mark.asyncio
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension): async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
"""Test that chunk IDs do not conflict across batches when inserting chunks.""" """Test that chunk IDs do not conflict across batches when inserting chunks."""
# Reduce batch size to force multiple batches for same document # Reduce batch size to force multiple batches for same document
@ -134,7 +128,6 @@ async def sqlite_vec_adapter(sqlite_connection):
await adapter.shutdown() await adapter.shutdown()
@pytest.mark.asyncio
async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search when keyword search returns no matches - should still return vector results.""" """Test hybrid search when keyword search returns no matches - should still return vector results."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -163,7 +156,6 @@ async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_c
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search with a high score threshold.""" """Test hybrid search with a high score threshold."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -185,7 +177,6 @@ async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chun
assert len(response.chunks) == 0 assert len(response.chunks) == 0
@pytest.mark.asyncio
async def test_query_chunks_hybrid_different_embedding( async def test_query_chunks_hybrid_different_embedding(
sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
): ):
@ -211,7 +202,6 @@ async def test_query_chunks_hybrid_different_embedding(
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test that RRF properly combines rankings when documents appear in both search methods.""" """Test that RRF properly combines rankings when documents appear in both search methods."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -236,7 +226,6 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -284,7 +273,6 @@ async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chun
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
@pytest.mark.asyncio
async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search with documents that appear in only one search method.""" """Test hybrid search with documents that appear in only one search method."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -313,7 +301,6 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
assert "document-2" in doc_ids # From keyword search assert "document-2" in doc_ids # From keyword search
@pytest.mark.asyncio
async def test_query_chunks_hybrid_weighted_reranker_parametrization( async def test_query_chunks_hybrid_weighted_reranker_parametrization(
sqlite_vec_index, sample_chunks, sample_embeddings sqlite_vec_index, sample_chunks, sample_embeddings
): ):
@ -369,7 +356,6 @@ async def test_query_chunks_hybrid_weighted_reranker_parametrization(
) )
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test RRFReRanker with different impact factors.""" """Test RRFReRanker with different impact factors."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -401,7 +387,6 @@ async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_ch
assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6) assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
@pytest.mark.asyncio
async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -445,7 +430,6 @@ async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, s
assert len(response.chunks) <= 100 assert len(response.chunks) <= 100
@pytest.mark.asyncio
async def test_query_chunks_hybrid_tie_breaking( async def test_query_chunks_hybrid_tie_breaking(
sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
): ):
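
One behavior worth keeping in mind after the switch: pytest-asyncio's default loop scope is per-test (`function`), so every test here still gets a fresh event loop even with the markers gone. A sketch that checks this, assuming the two tests execute in file order:

```python
import asyncio

_loops = []  # module-level scratch space; assumes default in-file ordering


async def test_first_records_loop():
    _loops.append(asyncio.get_running_loop())


async def test_second_gets_fresh_loop():
    _loops.append(asyncio.get_running_loop())
    # Function-scoped loops mean the two tests must not share a loop.
    assert _loops[0] is not _loops[1]
```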

View file

@ -25,12 +25,10 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREF
# -v -s --tb=short --disable-warnings --asyncio-mode=auto # -v -s --tb=short --disable-warnings --asyncio-mode=auto
@pytest.mark.asyncio
async def test_initialize_index(vector_index): async def test_initialize_index(vector_index):
await vector_index.initialize() await vector_index.initialize()
@pytest.mark.asyncio
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings): async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
vector_index.delete() vector_index.delete()
vector_index.initialize() vector_index.initialize()
@ -40,7 +38,6 @@ async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embed
vector_index.delete() vector_index.delete()
@pytest.mark.asyncio
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension): async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32) embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
await vector_index.add_chunks(sample_chunks, embeddings) await vector_index.add_chunks(sample_chunks, embeddings)
@ -54,7 +51,6 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
assert len(contents) == len(set(contents)) assert len(contents) == len(set(contents))
@pytest.mark.asyncio
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter): async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
key = f"{VECTOR_DBS_PREFIX}db1" key = f"{VECTOR_DBS_PREFIX}db1"
dummy = VectorDB( dummy = VectorDB(
@ -65,7 +61,6 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
await vector_io_adapter.initialize() await vector_io_adapter.initialize()
@pytest.mark.asyncio
async def test_persistence_across_adapter_restarts(vector_io_adapter): async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.initialize() await vector_io_adapter.initialize()
dummy = VectorDB( dummy = VectorDB(
@ -79,7 +74,6 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.shutdown() await vector_io_adapter.shutdown()
@pytest.mark.asyncio
async def test_register_and_unregister_vector_db(vector_io_adapter): async def test_register_and_unregister_vector_db(vector_io_adapter):
unique_id = f"foo_db_{np.random.randint(1e6)}" unique_id = f"foo_db_{np.random.randint(1e6)}"
dummy = VectorDB( dummy = VectorDB(
@ -92,14 +86,12 @@ async def test_register_and_unregister_vector_db(vector_io_adapter):
assert dummy.identifier not in vector_io_adapter.cache assert dummy.identifier not in vector_io_adapter.cache
@pytest.mark.asyncio
async def test_query_unregistered_raises(vector_io_adapter): async def test_query_unregistered_raises(vector_io_adapter):
fake_emb = np.zeros(8, dtype=np.float32) fake_emb = np.zeros(8, dtype=np.float32)
with pytest.raises(ValueError): with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("no_such_db", fake_emb) await vector_io_adapter.query_chunks("no_such_db", fake_emb)
@pytest.mark.asyncio
async def test_insert_chunks_calls_underlying_index(vector_io_adapter): async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
fake_index = AsyncMock() fake_index = AsyncMock()
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index) vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
@ -110,7 +102,6 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
fake_index.insert_chunks.assert_awaited_once_with(chunks) fake_index.insert_chunks.assert_awaited_once_with(chunks)
@pytest.mark.asyncio
async def test_insert_chunks_missing_db_raises(vector_io_adapter): async def test_insert_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
@ -118,7 +109,6 @@ async def test_insert_chunks_missing_db_raises(vector_io_adapter):
await vector_io_adapter.insert_chunks("db_not_exist", []) await vector_io_adapter.insert_chunks("db_not_exist", [])
@pytest.mark.asyncio
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter): async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1]) expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected)) fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
@ -130,7 +120,6 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
assert response is expected assert response is expected
@pytest.mark.asyncio
async def test_query_chunks_missing_db_raises(vector_io_adapter): async def test_query_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
@ -138,7 +127,6 @@ async def test_query_chunks_missing_db_raises(vector_io_adapter):
await vector_io_adapter.query_chunks("db_missing", "q", None) await vector_io_adapter.query_chunks("db_missing", "q", None)
@pytest.mark.asyncio
async def test_save_openai_vector_store(vector_io_adapter): async def test_save_openai_vector_store(vector_io_adapter):
store_id = "vs_1234" store_id = "vs_1234"
openai_vector_store = { openai_vector_store = {
@ -155,7 +143,6 @@ async def test_save_openai_vector_store(vector_io_adapter):
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
@pytest.mark.asyncio
async def test_update_openai_vector_store(vector_io_adapter): async def test_update_openai_vector_store(vector_io_adapter):
store_id = "vs_1234" store_id = "vs_1234"
openai_vector_store = { openai_vector_store = {
@ -172,7 +159,6 @@ async def test_update_openai_vector_store(vector_io_adapter):
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
@pytest.mark.asyncio
async def test_delete_openai_vector_store(vector_io_adapter): async def test_delete_openai_vector_store(vector_io_adapter):
store_id = "vs_1234" store_id = "vs_1234"
openai_vector_store = { openai_vector_store = {
@ -188,7 +174,6 @@ async def test_delete_openai_vector_store(vector_io_adapter):
assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores
@pytest.mark.asyncio
async def test_load_openai_vector_stores(vector_io_adapter): async def test_load_openai_vector_stores(vector_io_adapter):
store_id = "vs_1234" store_id = "vs_1234"
openai_vector_store = { openai_vector_store = {
@ -204,7 +189,6 @@ async def test_load_openai_vector_stores(vector_io_adapter):
assert loaded_stores[store_id] == openai_vector_store assert loaded_stores[store_id] == openai_vector_store
@pytest.mark.asyncio
async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory): async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234" store_id = "vs_1234"
file_id = "file_1234" file_id = "file_1234"
@ -226,7 +210,6 @@ async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
@pytest.mark.asyncio
async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory): async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234" store_id = "vs_1234"
file_id = "file_1234" file_id = "file_1234"
@ -260,7 +243,6 @@ async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_facto
assert loaded_contents != file_info assert loaded_contents != file_info
@pytest.mark.asyncio
async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory): async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234" store_id = "vs_1234"
file_id = "file_1234" file_id = "file_1234"
@ -284,7 +266,6 @@ async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_pat
assert loaded_contents == file_contents assert loaded_contents == file_contents
@pytest.mark.asyncio
async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory): async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234" store_id = "vs_1234"
file_id = "file_1234" file_id = "file_1234"

View file

@ -17,13 +17,11 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
class TestRagQuery: class TestRagQuery:
@pytest.mark.asyncio
async def test_query_raises_on_empty_vector_db_ids(self): async def test_query_raises_on_empty_vector_db_ids(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
with pytest.raises(ValueError): with pytest.raises(ValueError):
await rag_tool.query(content=MagicMock(), vector_db_ids=[]) await rag_tool.query(content=MagicMock(), vector_db_ids=[])
@pytest.mark.asyncio
async def test_query_chunk_metadata_handling(self): async def test_query_chunk_metadata_handling(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
content = "test query content" content = "test query content"
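
`TestRagQuery` covers the class-based case: auto mode also collects async methods on plain pytest test classes (though not on `unittest.TestCase` subclasses, which pytest-asyncio does not run). A sketch with a hypothetical `_Tool` stand-in:

```python
import pytest


class _Tool:
    """Hypothetical stand-in for the RAG tool above."""

    async def query(self, db_ids: list[str]) -> int:
        if not db_ids:
            raise ValueError("vector_db_ids must be non-empty")
        return len(db_ids)


class TestQuery:
    async def test_empty_ids_raise(self):
        with pytest.raises(ValueError):
            await _Tool().query([])

    async def test_counts_ids(self):
        assert await _Tool().query(["db1", "db2"]) == 2
```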

View file

@ -112,7 +112,6 @@ class TestValidateEmbedding:
class TestVectorStore: class TestVectorStore:
@pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self): async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH) data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = RAGDocument( doc = RAGDocument(
@ -124,7 +123,6 @@ class TestVectorStore:
content = await content_from_doc(doc) content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content(self): async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file # Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf" url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
@ -137,7 +135,6 @@ class TestVectorStore:
content = await content_from_doc(doc) content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content_with_url_object(self): async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file # Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf" url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
@ -204,7 +201,6 @@ class TestVectorStore:
class TestVectorDBWithIndex: class TestVectorDBWithIndex:
@pytest.mark.asyncio
async def test_insert_chunks_without_embeddings(self): async def test_insert_chunks_without_embeddings(self):
mock_vector_db = MagicMock() mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model without embeddings" mock_vector_db.embedding_model = "test-model without embeddings"
@ -230,7 +226,6 @@ class TestVectorDBWithIndex:
assert args[0] == chunks assert args[0] == chunks
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
@pytest.mark.asyncio
async def test_insert_chunks_with_valid_embeddings(self): async def test_insert_chunks_with_valid_embeddings(self):
mock_vector_db = MagicMock() mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with embeddings" mock_vector_db.embedding_model = "test-model with embeddings"
@ -255,7 +250,6 @@ class TestVectorDBWithIndex:
assert args[0] == chunks assert args[0] == chunks
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
@pytest.mark.asyncio
async def test_insert_chunks_with_invalid_embeddings(self): async def test_insert_chunks_with_invalid_embeddings(self):
mock_vector_db = MagicMock() mock_vector_db = MagicMock()
mock_vector_db.embedding_dimension = 3 mock_vector_db.embedding_dimension = 3
@ -295,7 +289,6 @@ class TestVectorDBWithIndex:
mock_inference_api.embeddings.assert_not_called() mock_inference_api.embeddings.assert_not_called()
mock_index.add_chunks.assert_not_called() mock_index.add_chunks.assert_not_called()
@pytest.mark.asyncio
async def test_insert_chunks_with_partially_precomputed_embeddings(self): async def test_insert_chunks_with_partially_precomputed_embeddings(self):
mock_vector_db = MagicMock() mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with partial embeddings" mock_vector_db.embedding_model = "test-model with partial embeddings"
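
These tests exercise async provider APIs with mocks rather than real models; `unittest.mock.AsyncMock` calls return awaitables, so no extra event-loop plumbing is needed. A sketch with invented method names, not the project's real inference API:

```python
from unittest.mock import AsyncMock

import numpy as np


async def test_embeddings_can_be_stubbed():
    inference = AsyncMock()
    inference.embeddings.return_value = [[0.1, 0.2, 0.3]]  # invented shape

    vectors = await inference.embeddings(["hello"])

    # AsyncMock records awaited calls, mirroring the assertions above.
    inference.embeddings.assert_awaited_once_with(["hello"])
    assert np.asarray(vectors, dtype=np.float32).shape == (1, 3)
```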

View file

@ -38,14 +38,12 @@ def sample_model():
) )
@pytest.mark.asyncio
async def test_registry_initialization(disk_dist_registry): async def test_registry_initialization(disk_dist_registry):
# Test empty registry # Test empty registry
result = await disk_dist_registry.get("nonexistent", "nonexistent") result = await disk_dist_registry.get("nonexistent", "nonexistent")
assert result is None assert result is None
@pytest.mark.asyncio
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model): async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
print(f"Registering {sample_vector_db}") print(f"Registering {sample_vector_db}")
await disk_dist_registry.register(sample_vector_db) await disk_dist_registry.register(sample_vector_db)
@ -64,7 +62,6 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m
assert result_model.provider_id == sample_model.provider_id assert result_model.provider_id == sample_model.provider_id
@pytest.mark.asyncio
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model): async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
# First populate the disk registry # First populate the disk registry
disk_registry = DiskDistributionRegistry(sqlite_kvstore) disk_registry = DiskDistributionRegistry(sqlite_kvstore)
@ -85,7 +82,6 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db,
assert result_vector_db.provider_id == sample_vector_db.provider_id assert result_vector_db.provider_id == sample_vector_db.provider_id
@pytest.mark.asyncio
async def test_cached_registry_updates(cached_disk_dist_registry): async def test_cached_registry_updates(cached_disk_dist_registry):
new_vector_db = VectorDB( new_vector_db = VectorDB(
identifier="test_vector_db_2", identifier="test_vector_db_2",
@ -112,7 +108,6 @@ async def test_cached_registry_updates(cached_disk_dist_registry):
assert result_vector_db.provider_id == new_vector_db.provider_id assert result_vector_db.provider_id == new_vector_db.provider_id
@pytest.mark.asyncio
async def test_duplicate_provider_registration(cached_disk_dist_registry): async def test_duplicate_provider_registration(cached_disk_dist_registry):
original_vector_db = VectorDB( original_vector_db = VectorDB(
identifier="test_vector_db_2", identifier="test_vector_db_2",
@ -137,7 +132,6 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
@pytest.mark.asyncio
async def test_get_all_objects(cached_disk_dist_registry): async def test_get_all_objects(cached_disk_dist_registry):
# Create multiple test banks # Create multiple test banks
# Create multiple test banks # Create multiple test banks
@ -170,7 +164,6 @@ async def test_get_all_objects(cached_disk_dist_registry):
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
@pytest.mark.asyncio
async def test_parse_registry_values_error_handling(sqlite_kvstore): async def test_parse_registry_values_error_handling(sqlite_kvstore):
valid_db = VectorDB( valid_db = VectorDB(
identifier="valid_vector_db", identifier="valid_vector_db",
@ -209,7 +202,6 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
assert invalid_obj is None assert invalid_obj is None
@pytest.mark.asyncio
async def test_cached_registry_error_handling(sqlite_kvstore): async def test_cached_registry_error_handling(sqlite_kvstore):
valid_db = VectorDB( valid_db = VectorDB(
identifier="valid_cached_db", identifier="valid_cached_db",

View file

@ -5,14 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
import pytest
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithOwner, User from llama_stack.distribution.datatypes import ModelWithOwner, User
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
@pytest.mark.asyncio
async def test_registry_cache_with_acl(cached_disk_dist_registry): async def test_registry_cache_with_acl(cached_disk_dist_registry):
model = ModelWithOwner( model = ModelWithOwner(
identifier="model-acl", identifier="model-acl",
@ -48,7 +45,6 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
assert new_model.owner.attributes["teams"] == ["ai-team"] assert new_model.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
async def test_registry_empty_acl(cached_disk_dist_registry): async def test_registry_empty_acl(cached_disk_dist_registry):
model = ModelWithOwner( model = ModelWithOwner(
identifier="model-empty-acl", identifier="model-empty-acl",
@ -85,7 +81,6 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
assert len(all_models) == 2 assert len(all_models) == 2
@pytest.mark.asyncio
async def test_registry_serialization(cached_disk_dist_registry): async def test_registry_serialization(cached_disk_dist_registry):
attributes = { attributes = {
"roles": ["admin", "researcher"], "roles": ["admin", "researcher"],

View file

@ -7,7 +7,6 @@
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
import pytest_asyncio
import yaml import yaml
from pydantic import TypeAdapter, ValidationError from pydantic import TypeAdapter, ValidationError
@ -27,7 +26,7 @@ def _return_model(model):
return model return model
@pytest_asyncio.fixture @pytest.fixture
async def test_setup(cached_disk_dist_registry): async def test_setup(cached_disk_dist_registry):
mock_inference = Mock() mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock() mock_inference.__provider_spec__ = MagicMock()
@ -41,7 +40,6 @@ async def test_setup(cached_disk_dist_registry):
yield cached_disk_dist_registry, routing_table yield cached_disk_dist_registry, routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup): async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
@ -106,7 +104,6 @@ async def test_access_control_with_cache(mock_get_authenticated_user, test_setup
await routing_table.get_model("model-admin") await routing_table.get_model("model-admin")
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup): async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
@ -145,7 +142,6 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
assert model.identifier == "model-updates" assert model.identifier == "model-updates"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup): async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
@ -170,7 +166,6 @@ async def test_access_control_empty_attributes(mock_get_authenticated_user, test
assert "model-empty-attrs" in model_ids assert "model-empty-attrs" in model_ids
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_no_user_attributes(mock_get_authenticated_user, test_setup): async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup registry, routing_table = test_setup
@ -201,7 +196,6 @@ async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
assert all_models.data[0].identifier == "model-public-2" assert all_models.data[0].identifier == "model-public-2"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup): async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
"""Test that newly created resources inherit access attributes from their creator.""" """Test that newly created resources inherit access attributes from their creator."""
@ -246,7 +240,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
assert model.identifier == "auto-access-model" assert model.identifier == "auto-access-model"
@pytest_asyncio.fixture @pytest.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry): async def test_setup_with_access_policy(cached_disk_dist_registry):
mock_inference = Mock() mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock() mock_inference.__provider_spec__ = MagicMock()
@ -281,7 +275,6 @@ async def test_setup_with_access_policy(cached_disk_dist_registry):
yield routing_table yield routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user") @patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy): async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
routing_table = test_setup_with_access_policy routing_table = test_setup_with_access_policy

View file

@ -202,7 +202,6 @@ def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoin
assert "param2" in payload["request"]["params"] assert "param2" in payload["request"]["params"]
@pytest.mark.asyncio
async def test_http_middleware_with_access_attributes(mock_http_middleware, mock_scope): async def test_http_middleware_with_access_attributes(mock_http_middleware, mock_scope):
"""Test HTTP middleware behavior with access attributes""" """Test HTTP middleware behavior with access attributes"""
middleware, mock_app = mock_http_middleware middleware, mock_app = mock_http_middleware

View file

@ -9,7 +9,6 @@ import sys
from typing import Any, Protocol from typing import Any, Protocol
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
@ -66,7 +65,6 @@ class SampleImpl:
pass pass
@pytest.mark.asyncio
async def test_resolve_impls_basic(): async def test_resolve_impls_basic():
# Create a real provider spec # Create a real provider spec
provider_spec = InlineProviderSpec( provider_spec = InlineProviderSpec(

View file

@ -7,13 +7,10 @@
import asyncio import asyncio
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.server.server import create_dynamic_typed_route, create_sse_event, sse_generator from llama_stack.distribution.server.server import create_dynamic_typed_route, create_sse_event, sse_generator
@pytest.mark.asyncio
async def test_sse_generator_basic(): async def test_sse_generator_basic():
# An AsyncIterator wrapped in an Awaitable, just like our web methods # An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen(): async def async_event_gen():
@ -35,7 +32,6 @@ async def test_sse_generator_basic():
assert seen_events[1] == create_sse_event("Test event 2") assert seen_events[1] == create_sse_event("Test event 2")
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected(): async def test_sse_generator_client_disconnected():
# An AsyncIterator wrapped in an Awaitable, just like our web methods # An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen(): async def async_event_gen():
@ -58,7 +54,6 @@ async def test_sse_generator_client_disconnected():
assert seen_events[0] == create_sse_event("Test event 1") assert seen_events[0] == create_sse_event("Test event 1")
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected_before_response_starts(): async def test_sse_generator_client_disconnected_before_response_starts():
# Disconnect before the response starts # Disconnect before the response starts
async def async_event_gen(): async def async_event_gen():
@ -75,7 +70,6 @@ async def test_sse_generator_client_disconnected_before_response_starts():
assert len(seen_events) == 0 assert len(seen_events) == 0
@pytest.mark.asyncio
async def test_sse_generator_error_before_response_starts(): async def test_sse_generator_error_before_response_starts():
# Raise an error before the response starts # Raise an error before the response starts
async def async_event_gen(): async def async_event_gen():
@ -93,7 +87,6 @@ async def test_sse_generator_error_before_response_starts():
assert 'data: {"error":' in seen_events[0] assert 'data: {"error":' in seen_events[0]
@pytest.mark.asyncio
async def test_paginated_response_url_setting(): async def test_paginated_response_url_setting():
"""Test that PaginatedResponse gets url set to route path.""" """Test that PaginatedResponse gets url set to route path."""

View file

@ -42,7 +42,6 @@ def create_test_chat_completion(
) )
@pytest.mark.asyncio
async def test_inference_store_pagination_basic(): async def test_inference_store_pagination_basic():
"""Test basic pagination functionality.""" """Test basic pagination functionality."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
@ -88,7 +87,6 @@ async def test_inference_store_pagination_basic():
assert result3.has_more is False assert result3.has_more is False
@pytest.mark.asyncio
async def test_inference_store_pagination_ascending(): async def test_inference_store_pagination_ascending():
"""Test pagination with ascending order.""" """Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
@ -123,7 +121,6 @@ async def test_inference_store_pagination_ascending():
assert result2.has_more is True assert result2.has_more is True
@pytest.mark.asyncio
async def test_inference_store_pagination_with_model_filter(): async def test_inference_store_pagination_with_model_filter():
"""Test pagination combined with model filtering.""" """Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
@ -161,7 +158,6 @@ async def test_inference_store_pagination_with_model_filter():
assert result2.has_more is False assert result2.has_more is False
@pytest.mark.asyncio
async def test_inference_store_pagination_invalid_after(): async def test_inference_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter.""" """Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
@ -174,7 +170,6 @@ async def test_inference_store_pagination_invalid_after():
await store.list_chat_completions(after="non-existent", limit=2) await store.list_chat_completions(after="non-existent", limit=2)
@pytest.mark.asyncio
async def test_inference_store_pagination_no_limit(): async def test_inference_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified.""" """Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:

View file

@ -44,7 +44,6 @@ def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInp
) )
@pytest.mark.asyncio
async def test_responses_store_pagination_basic(): async def test_responses_store_pagination_basic():
"""Test basic pagination functionality for responses store.""" """Test basic pagination functionality for responses store."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
@ -90,7 +89,6 @@ async def test_responses_store_pagination_basic():
assert result3.has_more is False assert result3.has_more is False
@pytest.mark.asyncio
async def test_responses_store_pagination_ascending(): async def test_responses_store_pagination_ascending():
"""Test pagination with ascending order.""" """Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
@ -125,7 +123,6 @@ async def test_responses_store_pagination_ascending():
assert result2.has_more is True assert result2.has_more is True
@pytest.mark.asyncio
async def test_responses_store_pagination_with_model_filter(): async def test_responses_store_pagination_with_model_filter():
"""Test pagination combined with model filtering.""" """Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
@ -163,7 +160,6 @@ async def test_responses_store_pagination_with_model_filter():
assert result2.has_more is False assert result2.has_more is False
@pytest.mark.asyncio
async def test_responses_store_pagination_invalid_after(): async def test_responses_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter.""" """Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
@ -176,7 +172,6 @@ async def test_responses_store_pagination_invalid_after():
await store.list_responses(after="non-existent", limit=2) await store.list_responses(after="non-existent", limit=2)
@pytest.mark.asyncio
async def test_responses_store_pagination_no_limit(): async def test_responses_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified.""" """Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
@ -205,7 +200,6 @@ async def test_responses_store_pagination_no_limit():
assert result.has_more is False assert result.has_more is False
@pytest.mark.asyncio
async def test_responses_store_get_response_object(): async def test_responses_store_get_response_object():
"""Test retrieving a single response object.""" """Test retrieving a single response object."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
@ -230,7 +224,6 @@ async def test_responses_store_get_response_object():
await store.get_response_object("non-existent") await store.get_response_object("non-existent")
@pytest.mark.asyncio
async def test_responses_store_input_items_pagination(): async def test_responses_store_input_items_pagination():
"""Test pagination functionality for input items.""" """Test pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
@ -308,7 +301,6 @@ async def test_responses_store_input_items_pagination():
await store.list_response_input_items("test-resp", before="some-id", after="other-id") await store.list_response_input_items("test-resp", before="some-id", after="other-id")
@pytest.mark.asyncio
async def test_responses_store_input_items_before_pagination(): async def test_responses_store_input_items_before_pagination():
"""Test before pagination functionality for input items.""" """Test before pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:

View file

@@ -14,7 +14,6 @@ from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemyS
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


-@pytest.mark.asyncio
 async def test_sqlite_sqlstore():
     with TemporaryDirectory() as tmp_dir:
         db_name = "test.db"
@@ -66,7 +65,6 @@ async def test_sqlite_sqlstore():
         assert result.has_more is False


-@pytest.mark.asyncio
 async def test_sqlstore_pagination_basic():
     """Test basic pagination functionality at the SQL store level."""
     with TemporaryDirectory() as tmp_dir:
@@ -131,7 +129,6 @@ async def test_sqlstore_pagination_basic():
         assert result3.has_more is False


-@pytest.mark.asyncio
 async def test_sqlstore_pagination_with_filter():
     """Test pagination with WHERE conditions."""
     with TemporaryDirectory() as tmp_dir:
@@ -184,7 +181,6 @@ async def test_sqlstore_pagination_with_filter():
         assert result2.has_more is False


-@pytest.mark.asyncio
 async def test_sqlstore_pagination_ascending_order():
     """Test pagination with ascending order."""
     with TemporaryDirectory() as tmp_dir:
@@ -233,7 +229,6 @@ async def test_sqlstore_pagination_ascending_order():
         assert result2.has_more is True


-@pytest.mark.asyncio
 async def test_sqlstore_pagination_multi_column_ordering_error():
     """Test that multi-column ordering raises an error when using cursor pagination."""
     with TemporaryDirectory() as tmp_dir:
@@ -271,7 +266,6 @@ async def test_sqlstore_pagination_multi_column_ordering_error():
         assert result.data[0]["id"] == "task1"


-@pytest.mark.asyncio
 async def test_sqlstore_pagination_cursor_requires_order_by():
     """Test that cursor pagination requires order_by parameter."""
     with TemporaryDirectory() as tmp_dir:
@@ -289,7 +283,6 @@ async def test_sqlstore_pagination_cursor_requires_order_by():
         )


-@pytest.mark.asyncio
 async def test_sqlstore_pagination_error_handling():
     """Test error handling for invalid columns and cursor IDs."""
     with TemporaryDirectory() as tmp_dir:
@@ -339,7 +332,6 @@ async def test_sqlstore_pagination_error_handling():
         )


-@pytest.mark.asyncio
 async def test_sqlstore_pagination_custom_key_column():
     """Test pagination with custom primary key column (not 'id')."""
     with TemporaryDirectory() as tmp_dir:
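
Fixtures get the same treatment: under pytest-asyncio >= 1.0 in auto mode, a plain `@pytest.fixture` on an async function is treated as an asyncio fixture, which is what makes `@pytest_asyncio.fixture` redundant. A sketch with a made-up fixture:

```python
# sketch, assuming auto mode; "seeded_rows" is a hypothetical fixture for illustration
import pytest


@pytest.fixture
async def seeded_rows():
    # stand-in for real async setup (e.g. opening a store in a temp dir)
    return [{"id": f"task{i}"} for i in range(3)]


async def test_fixture_resolves(seeded_rows):
    assert [row["id"] for row in seeded_rows] == ["task0", "task1", "task2"]
```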

View file

@@ -7,8 +7,6 @@
 from tempfile import TemporaryDirectory
 from unittest.mock import patch

-import pytest
-
 from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
 from llama_stack.distribution.access_control.datatypes import Action
 from llama_stack.distribution.datatypes import User
@@ -18,7 +16,6 @@ from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemyS
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


-@pytest.mark.asyncio
 @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
 async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
     """Test that fetch_all works correctly with where_sql for access control"""
@@ -81,7 +78,6 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
         assert row["title"] == "User Document"


-@pytest.mark.asyncio
 @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
 async def test_sql_policy_consistency(mock_get_authenticated_user):
     """Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic"""
@@ -168,7 +164,6 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
     )


-@pytest.mark.asyncio
 @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
 async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
     """Test that user attributes are properly captured during insert"""

uv.lock (generated, 17 changed lines)
View file

@@ -1394,8 +1394,8 @@ dev = [
     { name = "black" },
     { name = "nbval" },
     { name = "pre-commit" },
-    { name = "pytest" },
-    { name = "pytest-asyncio" },
+    { name = "pytest", specifier = ">=8.4" },
+    { name = "pytest-asyncio", specifier = ">=1.0" },
     { name = "pytest-cov" },
     { name = "pytest-html" },
     { name = "pytest-json-report" },
@@ -2432,29 +2432,30 @@ wheels = [

 [[package]]
 name = "pytest"
-version = "8.3.4"
+version = "8.4.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "colorama", marker = "sys_platform == 'win32'" },
     { name = "iniconfig" },
     { name = "packaging" },
     { name = "pluggy" },
+    { name = "pygments" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919, upload-time = "2024-12-01T12:54:25.98Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083, upload-time = "2024-12-01T12:54:19.735Z" },
+    { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" },
 ]

 [[package]]
 name = "pytest-asyncio"
-version = "0.25.3"
+version = "1.0.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "pytest" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/f2/a8/ecbc8ede70921dd2f544ab1cadd3ff3bf842af27f87bbdea774c7baa1d38/pytest_asyncio-0.25.3.tar.gz", hash = "sha256:fc1da2cf9f125ada7e710b4ddad05518d4cee187ae9412e9ac9271003497f07a", size = 54239, upload-time = "2025-01-28T18:37:58.729Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/d0/d4/14f53324cb1a6381bef29d698987625d80052bb33932d8e7cbf9b337b17c/pytest_asyncio-1.0.0.tar.gz", hash = "sha256:d15463d13f4456e1ead2594520216b225a16f781e144f8fdf6c5bb4667c48b3f", size = 46960, upload-time = "2025-05-26T04:54:40.484Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/67/17/3493c5624e48fd97156ebaec380dcaafee9506d7e2c46218ceebbb57d7de/pytest_asyncio-0.25.3-py3-none-any.whl", hash = "sha256:9e89518e0f9bd08928f97a3482fdc4e244df17529460bc038291ccaf8f85c7c3", size = 19467, upload-time = "2025-01-28T18:37:56.798Z" },
+    { url = "https://files.pythonhosted.org/packages/30/05/ce271016e351fddc8399e546f6e23761967ee09c8c568bbfbecb0c150171/pytest_asyncio-1.0.0-py3-none-any.whl", hash = "sha256:4f024da9f1ef945e680dc68610b52550e36590a67fd31bb3b4943979a1f90ef3", size = 15976, upload-time = "2025-05-26T04:54:39.035Z" },
 ]

 [[package]]
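
The new `pygments` entry above appears because pytest 8.4 lists it as a required dependency. To sanity-check that a local environment actually picked up the new floors, a quick sketch:

```python
# quick sanity check (sketch) that the installed versions satisfy the new floors
from importlib.metadata import version

assert tuple(int(p) for p in version("pytest").split(".")[:2]) >= (8, 4)
assert tuple(int(p) for p in version("pytest-asyncio").split(".")[:2]) >= (1, 0)
print("pytest", version("pytest"), "| pytest-asyncio", version("pytest-asyncio"))
```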