mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-14 17:16:09 +00:00
chore: default to pytest asyncio-mode=auto (#2730)
# What does this PR do? previously, developers who ran `./scripts/unit-tests.sh` would get `asyncio-mode=auto`, which meant `@pytest.mark.asyncio` and `@pytest_asyncio.fixture` were redundent. developers who ran `pytest` directly would get pytest's default (strict mode), would run into errors leading them to add `@pytest.mark.asyncio` / `@pytest_asyncio.fixture` to their code. with this change - - `asyncio_mode=auto` is included in `pyproject.toml` making behavior consistent for all invocations of pytest - removes all redundant `@pytest_asyncio.fixture` and `@pytest.mark.asyncio` - for good measure, requires `pytest>=8.4` and `pytest-asyncio>=1.0` ## Test Plan - `./scripts/unit-tests.sh` - `uv run pytest tests/unit`
This commit is contained in:
parent
2ebc172f33
commit
30b2e6a495
35 changed files with 29 additions and 239 deletions
|
@ -58,9 +58,9 @@ ui = [
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
"pytest",
|
"pytest>=8.4",
|
||||||
"pytest-timeout",
|
"pytest-timeout",
|
||||||
"pytest-asyncio",
|
"pytest-asyncio>=1.0",
|
||||||
"pytest-cov",
|
"pytest-cov",
|
||||||
"pytest-html",
|
"pytest-html",
|
||||||
"pytest-json-report",
|
"pytest-json-report",
|
||||||
|
@ -339,3 +339,6 @@ warn_required_dynamic_aliases = true
|
||||||
|
|
||||||
[tool.ruff.lint.pep8-naming]
|
[tool.ruff.lint.pep8-naming]
|
||||||
classmethod-decorators = ["classmethod", "pydantic.field_validator"]
|
classmethod-decorators = ["classmethod", "pydantic.field_validator"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
|
|
@ -16,4 +16,4 @@ if [ $FOUND_PYTHON -ne 0 ]; then
|
||||||
uv python install "$PYTHON_VERSION"
|
uv python install "$PYTHON_VERSION"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest --asyncio-mode=auto -s -v tests/unit/ $@
|
uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest -s -v tests/unit/ $@
|
||||||
|
|
|
@ -44,7 +44,6 @@ def common_params(inference_model):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world")
|
@pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world")
|
||||||
async def test_delete_agents_and_sessions(self, agents_stack, common_params):
|
async def test_delete_agents_and_sessions(self, agents_stack, common_params):
|
||||||
agents_impl = agents_stack.impls[Api.agents]
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
|
@ -73,7 +72,6 @@ async def test_delete_agents_and_sessions(self, agents_stack, common_params):
|
||||||
assert agent_response is None
|
assert agent_response is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world")
|
@pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world")
|
||||||
async def test_get_agent_turns_and_steps(self, agents_stack, sample_messages, common_params):
|
async def test_get_agent_turns_and_steps(self, agents_stack, sample_messages, common_params):
|
||||||
agents_impl = agents_stack.impls[Api.agents]
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
|
|
|
@ -4,20 +4,17 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import pytest
|
|
||||||
from llama_stack_client import LlamaStackClient
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
from llama_stack import LlamaStackAsLibraryClient
|
from llama_stack import LlamaStackAsLibraryClient
|
||||||
|
|
||||||
|
|
||||||
class TestInspect:
|
class TestInspect:
|
||||||
@pytest.mark.asyncio
|
|
||||||
def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||||
health = llama_stack_client.inspect.health()
|
health = llama_stack_client.inspect.health()
|
||||||
assert health is not None
|
assert health is not None
|
||||||
assert health.status == "OK"
|
assert health.status == "OK"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||||
version = llama_stack_client.inspect.version()
|
version = llama_stack_client.inspect.version()
|
||||||
assert version is not None
|
assert version is not None
|
||||||
|
|
|
@ -4,14 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import pytest
|
|
||||||
from llama_stack_client import LlamaStackClient
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
from llama_stack import LlamaStackAsLibraryClient
|
from llama_stack import LlamaStackAsLibraryClient
|
||||||
|
|
||||||
|
|
||||||
class TestProviders:
|
class TestProviders:
|
||||||
@pytest.mark.asyncio
|
|
||||||
def test_providers(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
def test_providers(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||||
provider_list = llama_stack_client.providers.list()
|
provider_list = llama_stack_client.providers.list()
|
||||||
assert provider_list is not None
|
assert provider_list is not None
|
||||||
|
|
|
@ -88,7 +88,6 @@ async def cleanup_records(sql_store, table_name, record_ids):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
|
@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
|
||||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request):
|
async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request):
|
||||||
|
@ -183,7 +182,6 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
||||||
await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"])
|
await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
|
@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
|
||||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request):
|
async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request):
|
||||||
|
|
|
@ -8,8 +8,6 @@
|
||||||
|
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
|
@ -119,7 +117,6 @@ class ToolGroupsImpl(Impl):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_models_routing_table(cached_disk_dist_registry):
|
async def test_models_routing_table(cached_disk_dist_registry):
|
||||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
@ -161,7 +158,6 @@ async def test_models_routing_table(cached_disk_dist_registry):
|
||||||
assert len(openai_models.data) == 0
|
assert len(openai_models.data) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_shields_routing_table(cached_disk_dist_registry):
|
async def test_shields_routing_table(cached_disk_dist_registry):
|
||||||
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
|
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
@ -177,7 +173,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
||||||
assert "test-shield-2" in shield_ids
|
assert "test-shield-2" in shield_ids
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
@ -233,7 +228,6 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||||
assert len(datasets.data) == 0
|
assert len(datasets.data) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
||||||
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
|
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
@ -259,7 +253,6 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
||||||
assert "test-scoring-fn-2" in scoring_fn_ids
|
assert "test-scoring-fn-2" in scoring_fn_ids
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
|
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
@ -277,7 +270,6 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||||
assert "test-benchmark" in benchmark_ids
|
assert "test-benchmark" in benchmark_ids
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
||||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
|
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
|
@ -13,7 +13,6 @@ import pytest
|
||||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_preserve_contexts_with_exception():
|
async def test_preserve_contexts_with_exception():
|
||||||
# Create context variable
|
# Create context variable
|
||||||
context_var = ContextVar("exception_var", default="initial")
|
context_var = ContextVar("exception_var", default="initial")
|
||||||
|
@ -41,7 +40,6 @@ async def test_preserve_contexts_with_exception():
|
||||||
context_var.reset(token)
|
context_var.reset(token)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_preserve_contexts_empty_generator():
|
async def test_preserve_contexts_empty_generator():
|
||||||
# Create context variable
|
# Create context variable
|
||||||
context_var = ContextVar("empty_var", default="initial")
|
context_var = ContextVar("empty_var", default="initial")
|
||||||
|
@ -66,7 +64,6 @@ async def test_preserve_contexts_empty_generator():
|
||||||
context_var.reset(token)
|
context_var.reset(token)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_preserve_contexts_across_event_loops():
|
async def test_preserve_contexts_across_event_loops():
|
||||||
"""
|
"""
|
||||||
Test that context variables are preserved across event loop boundaries with nested generators.
|
Test that context variables are preserved across event loop boundaries with nested generators.
|
||||||
|
|
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.common.responses import Order
|
from llama_stack.apis.common.responses import Order
|
||||||
from llama_stack.apis.files import OpenAIFilePurpose
|
from llama_stack.apis.files import OpenAIFilePurpose
|
||||||
|
@ -29,7 +28,7 @@ class MockUploadFile:
|
||||||
return self.content
|
return self.content
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def files_provider(tmp_path):
|
async def files_provider(tmp_path):
|
||||||
"""Create a files provider with temporary storage for testing."""
|
"""Create a files provider with temporary storage for testing."""
|
||||||
storage_dir = tmp_path / "files"
|
storage_dir = tmp_path / "files"
|
||||||
|
@ -68,7 +67,6 @@ def large_file():
|
||||||
class TestOpenAIFilesAPI:
|
class TestOpenAIFilesAPI:
|
||||||
"""Test suite for OpenAI Files API endpoints."""
|
"""Test suite for OpenAI Files API endpoints."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_upload_file_success(self, files_provider, sample_text_file):
|
async def test_upload_file_success(self, files_provider, sample_text_file):
|
||||||
"""Test successful file upload."""
|
"""Test successful file upload."""
|
||||||
# Upload file
|
# Upload file
|
||||||
|
@ -82,7 +80,6 @@ class TestOpenAIFilesAPI:
|
||||||
assert result.created_at > 0
|
assert result.created_at > 0
|
||||||
assert result.expires_at > result.created_at
|
assert result.expires_at > result.created_at
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_upload_different_purposes(self, files_provider, sample_text_file):
|
async def test_upload_different_purposes(self, files_provider, sample_text_file):
|
||||||
"""Test uploading files with different purposes."""
|
"""Test uploading files with different purposes."""
|
||||||
purposes = list(OpenAIFilePurpose)
|
purposes = list(OpenAIFilePurpose)
|
||||||
|
@ -93,7 +90,6 @@ class TestOpenAIFilesAPI:
|
||||||
uploaded_files.append(result)
|
uploaded_files.append(result)
|
||||||
assert result.purpose == purpose
|
assert result.purpose == purpose
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_upload_different_file_types(self, files_provider, sample_text_file, sample_json_file, large_file):
|
async def test_upload_different_file_types(self, files_provider, sample_text_file, sample_json_file, large_file):
|
||||||
"""Test uploading different types and sizes of files."""
|
"""Test uploading different types and sizes of files."""
|
||||||
files_to_test = [
|
files_to_test = [
|
||||||
|
@ -107,7 +103,6 @@ class TestOpenAIFilesAPI:
|
||||||
assert result.filename == expected_filename
|
assert result.filename == expected_filename
|
||||||
assert result.bytes == len(file_obj.content)
|
assert result.bytes == len(file_obj.content)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_files_empty(self, files_provider):
|
async def test_list_files_empty(self, files_provider):
|
||||||
"""Test listing files when no files exist."""
|
"""Test listing files when no files exist."""
|
||||||
result = await files_provider.openai_list_files()
|
result = await files_provider.openai_list_files()
|
||||||
|
@ -117,7 +112,6 @@ class TestOpenAIFilesAPI:
|
||||||
assert result.first_id == ""
|
assert result.first_id == ""
|
||||||
assert result.last_id == ""
|
assert result.last_id == ""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_files_with_content(self, files_provider, sample_text_file, sample_json_file):
|
async def test_list_files_with_content(self, files_provider, sample_text_file, sample_json_file):
|
||||||
"""Test listing files when files exist."""
|
"""Test listing files when files exist."""
|
||||||
# Upload multiple files
|
# Upload multiple files
|
||||||
|
@ -132,7 +126,6 @@ class TestOpenAIFilesAPI:
|
||||||
assert file1.id in file_ids
|
assert file1.id in file_ids
|
||||||
assert file2.id in file_ids
|
assert file2.id in file_ids
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_files_with_purpose_filter(self, files_provider, sample_text_file):
|
async def test_list_files_with_purpose_filter(self, files_provider, sample_text_file):
|
||||||
"""Test listing files with purpose filtering."""
|
"""Test listing files with purpose filtering."""
|
||||||
# Upload file with specific purpose
|
# Upload file with specific purpose
|
||||||
|
@ -146,7 +139,6 @@ class TestOpenAIFilesAPI:
|
||||||
assert result.data[0].id == uploaded_file.id
|
assert result.data[0].id == uploaded_file.id
|
||||||
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
|
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_files_with_limit(self, files_provider, sample_text_file):
|
async def test_list_files_with_limit(self, files_provider, sample_text_file):
|
||||||
"""Test listing files with limit parameter."""
|
"""Test listing files with limit parameter."""
|
||||||
# Upload multiple files
|
# Upload multiple files
|
||||||
|
@ -157,7 +149,6 @@ class TestOpenAIFilesAPI:
|
||||||
result = await files_provider.openai_list_files(limit=3)
|
result = await files_provider.openai_list_files(limit=3)
|
||||||
assert len(result.data) == 3
|
assert len(result.data) == 3
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_files_with_order(self, files_provider, sample_text_file):
|
async def test_list_files_with_order(self, files_provider, sample_text_file):
|
||||||
"""Test listing files with different order."""
|
"""Test listing files with different order."""
|
||||||
# Upload multiple files
|
# Upload multiple files
|
||||||
|
@ -178,7 +169,6 @@ class TestOpenAIFilesAPI:
|
||||||
# Oldest should be first
|
# Oldest should be first
|
||||||
assert result_asc.data[0].created_at <= result_asc.data[1].created_at <= result_asc.data[2].created_at
|
assert result_asc.data[0].created_at <= result_asc.data[1].created_at <= result_asc.data[2].created_at
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_retrieve_file_success(self, files_provider, sample_text_file):
|
async def test_retrieve_file_success(self, files_provider, sample_text_file):
|
||||||
"""Test successful file retrieval."""
|
"""Test successful file retrieval."""
|
||||||
# Upload file
|
# Upload file
|
||||||
|
@ -197,13 +187,11 @@ class TestOpenAIFilesAPI:
|
||||||
assert retrieved_file.created_at == uploaded_file.created_at
|
assert retrieved_file.created_at == uploaded_file.created_at
|
||||||
assert retrieved_file.expires_at == uploaded_file.expires_at
|
assert retrieved_file.expires_at == uploaded_file.expires_at
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_retrieve_file_not_found(self, files_provider):
|
async def test_retrieve_file_not_found(self, files_provider):
|
||||||
"""Test retrieving a non-existent file."""
|
"""Test retrieving a non-existent file."""
|
||||||
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
||||||
await files_provider.openai_retrieve_file("file-nonexistent")
|
await files_provider.openai_retrieve_file("file-nonexistent")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
|
async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
|
||||||
"""Test successful file content retrieval."""
|
"""Test successful file content retrieval."""
|
||||||
# Upload file
|
# Upload file
|
||||||
|
@ -217,13 +205,11 @@ class TestOpenAIFilesAPI:
|
||||||
# Verify content
|
# Verify content
|
||||||
assert content.body == sample_text_file.content
|
assert content.body == sample_text_file.content
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_retrieve_file_content_not_found(self, files_provider):
|
async def test_retrieve_file_content_not_found(self, files_provider):
|
||||||
"""Test retrieving content of a non-existent file."""
|
"""Test retrieving content of a non-existent file."""
|
||||||
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
||||||
await files_provider.openai_retrieve_file_content("file-nonexistent")
|
await files_provider.openai_retrieve_file_content("file-nonexistent")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_file_success(self, files_provider, sample_text_file):
|
async def test_delete_file_success(self, files_provider, sample_text_file):
|
||||||
"""Test successful file deletion."""
|
"""Test successful file deletion."""
|
||||||
# Upload file
|
# Upload file
|
||||||
|
@ -245,13 +231,11 @@ class TestOpenAIFilesAPI:
|
||||||
with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
|
with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
|
||||||
await files_provider.openai_retrieve_file(uploaded_file.id)
|
await files_provider.openai_retrieve_file(uploaded_file.id)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_file_not_found(self, files_provider):
|
async def test_delete_file_not_found(self, files_provider):
|
||||||
"""Test deleting a non-existent file."""
|
"""Test deleting a non-existent file."""
|
||||||
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
||||||
await files_provider.openai_delete_file("file-nonexistent")
|
await files_provider.openai_delete_file("file-nonexistent")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_file_persistence_across_operations(self, files_provider, sample_text_file):
|
async def test_file_persistence_across_operations(self, files_provider, sample_text_file):
|
||||||
"""Test that files persist correctly across multiple operations."""
|
"""Test that files persist correctly across multiple operations."""
|
||||||
# Upload file
|
# Upload file
|
||||||
|
@ -279,7 +263,6 @@ class TestOpenAIFilesAPI:
|
||||||
files_list = await files_provider.openai_list_files()
|
files_list = await files_provider.openai_list_files()
|
||||||
assert len(files_list.data) == 0
|
assert len(files_list.data) == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_multiple_files_operations(self, files_provider, sample_text_file, sample_json_file):
|
async def test_multiple_files_operations(self, files_provider, sample_text_file, sample_json_file):
|
||||||
"""Test operations with multiple files."""
|
"""Test operations with multiple files."""
|
||||||
# Upload multiple files
|
# Upload multiple files
|
||||||
|
@ -302,7 +285,6 @@ class TestOpenAIFilesAPI:
|
||||||
content = await files_provider.openai_retrieve_file_content(file2.id)
|
content = await files_provider.openai_retrieve_file_content(file2.id)
|
||||||
assert content.body == sample_json_file.content
|
assert content.body == sample_json_file.content
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_file_id_uniqueness(self, files_provider, sample_text_file):
|
async def test_file_id_uniqueness(self, files_provider, sample_text_file):
|
||||||
"""Test that each uploaded file gets a unique ID."""
|
"""Test that each uploaded file gets a unique ID."""
|
||||||
file_ids = set()
|
file_ids = set()
|
||||||
|
@ -316,7 +298,6 @@ class TestOpenAIFilesAPI:
|
||||||
file_ids.add(uploaded_file.id)
|
file_ids.add(uploaded_file.id)
|
||||||
assert uploaded_file.id.startswith("file-")
|
assert uploaded_file.id.startswith("file-")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_file_no_filename_handling(self, files_provider):
|
async def test_file_no_filename_handling(self, files_provider):
|
||||||
"""Test handling files with no filename."""
|
"""Test handling files with no filename."""
|
||||||
file_without_name = MockUploadFile(b"content", None) # No filename
|
file_without_name = MockUploadFile(b"content", None) # No filename
|
||||||
|
@ -327,7 +308,6 @@ class TestOpenAIFilesAPI:
|
||||||
|
|
||||||
assert uploaded_file.filename == "uploaded_file" # Default filename
|
assert uploaded_file.filename == "uploaded_file" # Default filename
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_after_pagination_works(self, files_provider, sample_text_file):
|
async def test_after_pagination_works(self, files_provider, sample_text_file):
|
||||||
"""Test that 'after' pagination works correctly."""
|
"""Test that 'after' pagination works correctly."""
|
||||||
# Upload multiple files to test pagination
|
# Upload multiple files to test pagination
|
||||||
|
|
|
@ -4,14 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import pytest_asyncio
|
import pytest
|
||||||
|
|
||||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
|
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
async def sqlite_kvstore(tmp_path):
|
async def sqlite_kvstore(tmp_path):
|
||||||
db_path = tmp_path / "test_kv.db"
|
db_path = tmp_path / "test_kv.db"
|
||||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
|
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
|
||||||
|
@ -20,14 +20,14 @@ async def sqlite_kvstore(tmp_path):
|
||||||
yield kvstore
|
yield kvstore
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
async def disk_dist_registry(sqlite_kvstore):
|
async def disk_dist_registry(sqlite_kvstore):
|
||||||
registry = DiskDistributionRegistry(sqlite_kvstore)
|
registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||||
await registry.initialize()
|
await registry.initialize()
|
||||||
yield registry
|
yield registry
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
async def cached_disk_dist_registry(sqlite_kvstore):
|
async def cached_disk_dist_registry(sqlite_kvstore):
|
||||||
registry = CachedDiskDistributionRegistry(sqlite_kvstore)
|
registry = CachedDiskDistributionRegistry(sqlite_kvstore)
|
||||||
await registry.initialize()
|
await registry.initialize()
|
||||||
|
|
|
@ -8,7 +8,6 @@ from datetime import datetime
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
Agent,
|
Agent,
|
||||||
|
@ -50,7 +49,7 @@ def config(tmp_path):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def agents_impl(config, mock_apis):
|
async def agents_impl(config, mock_apis):
|
||||||
impl = MetaReferenceAgentsImpl(
|
impl = MetaReferenceAgentsImpl(
|
||||||
config,
|
config,
|
||||||
|
@ -117,7 +116,6 @@ def sample_agent_config():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_agent(agents_impl, sample_agent_config):
|
async def test_create_agent(agents_impl, sample_agent_config):
|
||||||
response = await agents_impl.create_agent(sample_agent_config)
|
response = await agents_impl.create_agent(sample_agent_config)
|
||||||
|
|
||||||
|
@ -132,7 +130,6 @@ async def test_create_agent(agents_impl, sample_agent_config):
|
||||||
assert isinstance(agent_info.created_at, datetime)
|
assert isinstance(agent_info.created_at, datetime)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_agent(agents_impl, sample_agent_config):
|
async def test_get_agent(agents_impl, sample_agent_config):
|
||||||
create_response = await agents_impl.create_agent(sample_agent_config)
|
create_response = await agents_impl.create_agent(sample_agent_config)
|
||||||
agent_id = create_response.agent_id
|
agent_id = create_response.agent_id
|
||||||
|
@ -146,7 +143,6 @@ async def test_get_agent(agents_impl, sample_agent_config):
|
||||||
assert isinstance(agent.created_at, datetime)
|
assert isinstance(agent.created_at, datetime)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_agents(agents_impl, sample_agent_config):
|
async def test_list_agents(agents_impl, sample_agent_config):
|
||||||
agent1_response = await agents_impl.create_agent(sample_agent_config)
|
agent1_response = await agents_impl.create_agent(sample_agent_config)
|
||||||
agent2_response = await agents_impl.create_agent(sample_agent_config)
|
agent2_response = await agents_impl.create_agent(sample_agent_config)
|
||||||
|
@ -160,7 +156,6 @@ async def test_list_agents(agents_impl, sample_agent_config):
|
||||||
assert agent2_response.agent_id in agent_ids
|
assert agent2_response.agent_id in agent_ids
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
||||||
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
|
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
|
||||||
# Create an agent with specified persistence setting
|
# Create an agent with specified persistence setting
|
||||||
|
@ -188,7 +183,6 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config
|
||||||
await agents_impl.get_agents_session(agent_id, session_response.session_id)
|
await agents_impl.get_agents_session(agent_id, session_response.session_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
||||||
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
|
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
|
||||||
# Create an agent with specified persistence setting
|
# Create an agent with specified persistence setting
|
||||||
|
@ -221,7 +215,6 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config,
|
||||||
assert session2.session_id in {s["session_id"] for s in sessions.data}
|
assert session2.session_id in {s["session_id"] for s in sessions.data}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_agent(agents_impl, sample_agent_config):
|
async def test_delete_agent(agents_impl, sample_agent_config):
|
||||||
# Create an agent
|
# Create an agent
|
||||||
response = await agents_impl.create_agent(sample_agent_config)
|
response = await agents_impl.create_agent(sample_agent_config)
|
||||||
|
|
|
@ -122,7 +122,6 @@ async def fake_stream(fixture: str = "simple_chat_completion.yaml"):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
|
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
|
||||||
"""Test creating an OpenAI response with a simple string input."""
|
"""Test creating an OpenAI response with a simple string input."""
|
||||||
# Setup
|
# Setup
|
||||||
|
@ -155,7 +154,6 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
||||||
assert result.output[0].content[0].text == "Dublin"
|
assert result.output[0].content[0].text == "Dublin"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
|
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
|
||||||
"""Test creating an OpenAI response with a simple string input and tools."""
|
"""Test creating an OpenAI response with a simple string input and tools."""
|
||||||
# Setup
|
# Setup
|
||||||
|
@ -224,7 +222,6 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
|
||||||
assert result.output[1].content[0].annotations == []
|
assert result.output[1].content[0].annotations == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api):
|
async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api):
|
||||||
"""Test creating an OpenAI response with a tool call response that has a type of None."""
|
"""Test creating an OpenAI response with a tool call response that has a type of None."""
|
||||||
# Setup
|
# Setup
|
||||||
|
@ -294,7 +291,6 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
||||||
assert chunks[1].response.output[0].name == "get_weather"
|
assert chunks[1].response.output[0].name == "get_weather"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
|
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
|
||||||
"""Test creating an OpenAI response with multiple messages."""
|
"""Test creating an OpenAI response with multiple messages."""
|
||||||
# Setup
|
# Setup
|
||||||
|
@ -340,7 +336,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
|
||||||
assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
|
assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prepend_previous_response_none(openai_responses_impl):
|
async def test_prepend_previous_response_none(openai_responses_impl):
|
||||||
"""Test prepending no previous response to a new response."""
|
"""Test prepending no previous response to a new response."""
|
||||||
|
|
||||||
|
@ -348,7 +343,6 @@ async def test_prepend_previous_response_none(openai_responses_impl):
|
||||||
assert input == "fake_input"
|
assert input == "fake_input"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
|
async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
|
||||||
"""Test prepending a basic previous response to a new response."""
|
"""Test prepending a basic previous response to a new response."""
|
||||||
|
|
||||||
|
@ -388,7 +382,6 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
|
||||||
assert input[2].content == "fake_input"
|
assert input[2].content == "fake_input"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store):
|
async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store):
|
||||||
"""Test prepending a web search previous response to a new response."""
|
"""Test prepending a web search previous response to a new response."""
|
||||||
input_item_message = OpenAIResponseMessage(
|
input_item_message = OpenAIResponseMessage(
|
||||||
|
@ -434,7 +427,6 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
|
||||||
assert input[3].content == "fake_input"
|
assert input[3].content == "fake_input"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
|
async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
|
||||||
# Setup
|
# Setup
|
||||||
input_text = "What is the capital of Ireland?"
|
input_text = "What is the capital of Ireland?"
|
||||||
|
@ -463,7 +455,6 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
|
||||||
assert sent_messages[1].content == input_text
|
assert sent_messages[1].content == input_text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_openai_response_with_instructions_and_multiple_messages(
|
async def test_create_openai_response_with_instructions_and_multiple_messages(
|
||||||
openai_responses_impl, mock_inference_api
|
openai_responses_impl, mock_inference_api
|
||||||
):
|
):
|
||||||
|
@ -508,7 +499,6 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
|
||||||
assert sent_messages[3].content == "Which is the largest?"
|
assert sent_messages[3].content == "Which is the largest?"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_openai_response_with_instructions_and_previous_response(
|
async def test_create_openai_response_with_instructions_and_previous_response(
|
||||||
openai_responses_impl, mock_responses_store, mock_inference_api
|
openai_responses_impl, mock_responses_store, mock_inference_api
|
||||||
):
|
):
|
||||||
|
@ -565,7 +555,6 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
||||||
assert sent_messages[3].content == "Which is the largest?"
|
assert sent_messages[3].content == "Which is the largest?"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
|
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
|
||||||
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
|
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
|
||||||
# Setup
|
# Setup
|
||||||
|
@ -601,7 +590,6 @@ async def test_list_openai_response_input_items_delegation(openai_responses_impl
|
||||||
assert result.data[0].id == "msg_123"
|
assert result.data[0].id == "msg_123"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_responses_store_list_input_items_logic():
|
async def test_responses_store_list_input_items_logic():
|
||||||
"""Test ResponsesStore list_response_input_items logic - mocks get_response_object to test actual ordering/limiting."""
|
"""Test ResponsesStore list_response_input_items logic - mocks get_response_object to test actual ordering/limiting."""
|
||||||
|
|
||||||
|
@ -680,7 +668,6 @@ async def test_responses_store_list_input_items_logic():
|
||||||
assert len(result.data) == 0 # Should return no items
|
assert len(result.data) == 0 # Should return no items
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_store_response_uses_rehydrated_input_with_previous_response(
|
async def test_store_response_uses_rehydrated_input_with_previous_response(
|
||||||
openai_responses_impl, mock_responses_store, mock_inference_api
|
openai_responses_impl, mock_responses_store, mock_inference_api
|
||||||
):
|
):
|
||||||
|
@ -747,7 +734,6 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"text_format, response_format",
|
"text_format, response_format",
|
||||||
[
|
[
|
||||||
|
@ -787,7 +773,6 @@ async def test_create_openai_response_with_text_format(
|
||||||
assert first_call.kwargs["response_format"] == response_format
|
assert first_call.kwargs["response_format"] == response_format
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
|
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
|
||||||
"""Test creating an OpenAI response with an invalid text format."""
|
"""Test creating an OpenAI response with an invalid text format."""
|
||||||
# Setup
|
# Setup
|
||||||
|
|
|
@ -9,7 +9,6 @@ from datetime import datetime
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import Turn
|
from llama_stack.apis.agents import Turn
|
||||||
from llama_stack.apis.inference import CompletionMessage, StopReason
|
from llama_stack.apis.inference import CompletionMessage, StopReason
|
||||||
|
@ -17,13 +16,12 @@ from llama_stack.distribution.datatypes import User
|
||||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def test_setup(sqlite_kvstore):
|
async def test_setup(sqlite_kvstore):
|
||||||
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
|
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
|
||||||
yield agent_persistence
|
yield agent_persistence
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||||
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
|
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
|
||||||
agent_persistence = test_setup
|
agent_persistence = test_setup
|
||||||
|
@ -44,7 +42,6 @@ async def test_session_creation_with_access_attributes(mock_get_authenticated_us
|
||||||
assert session_info.owner.attributes["teams"] == ["ai-team"]
|
assert session_info.owner.attributes["teams"] == ["ai-team"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||||
async def test_session_access_control(mock_get_authenticated_user, test_setup):
|
async def test_session_access_control(mock_get_authenticated_user, test_setup):
|
||||||
agent_persistence = test_setup
|
agent_persistence = test_setup
|
||||||
|
@ -79,7 +76,6 @@ async def test_session_access_control(mock_get_authenticated_user, test_setup):
|
||||||
assert retrieved_session is None
|
assert retrieved_session is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||||
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
|
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
|
||||||
agent_persistence = test_setup
|
agent_persistence = test_setup
|
||||||
|
@ -133,7 +129,6 @@ async def test_turn_access_control(mock_get_authenticated_user, test_setup):
|
||||||
await agent_persistence.get_session_turns(session_id)
|
await agent_persistence.get_session_turns(session_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
|
||||||
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
|
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
|
||||||
agent_persistence = test_setup
|
agent_persistence = test_setup
|
||||||
|
|
|
@ -14,7 +14,6 @@ from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
from openai.types.chat.chat_completion_chunk import (
|
from openai.types.chat.chat_completion_chunk import (
|
||||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||||
)
|
)
|
||||||
|
@ -103,7 +102,7 @@ def mock_openai_models_list():
|
||||||
yield mock_list
|
yield mock_list
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
async def vllm_inference_adapter():
|
async def vllm_inference_adapter():
|
||||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||||
inference_adapter = VLLMInferenceAdapter(config)
|
inference_adapter = VLLMInferenceAdapter(config)
|
||||||
|
@ -112,7 +111,6 @@ async def vllm_inference_adapter():
|
||||||
return inference_adapter
|
return inference_adapter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
|
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
|
||||||
async def mock_openai_models():
|
async def mock_openai_models():
|
||||||
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
|
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
|
||||||
|
@ -125,7 +123,6 @@ async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inferenc
|
||||||
mock_openai_models_list.assert_called()
|
mock_openai_models_list.assert_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
||||||
"""
|
"""
|
||||||
Test that we set tool_choice to none when no tools are in use
|
Test that we set tool_choice to none when no tools are in use
|
||||||
|
@ -149,7 +146,6 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
||||||
assert request.tool_config.tool_choice == ToolChoice.none
|
assert request.tool_config.tool_choice == ToolChoice.none
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_call_response(vllm_inference_adapter):
|
async def test_tool_call_response(vllm_inference_adapter):
|
||||||
"""Verify that tool call arguments from a CompletionMessage are correctly converted
|
"""Verify that tool call arguments from a CompletionMessage are correctly converted
|
||||||
into the expected JSON format."""
|
into the expected JSON format."""
|
||||||
|
@ -192,7 +188,6 @@ async def test_tool_call_response(vllm_inference_adapter):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_call_delta_empty_tool_call_buf():
|
async def test_tool_call_delta_empty_tool_call_buf():
|
||||||
"""
|
"""
|
||||||
Test that we don't generate extra chunks when processing a
|
Test that we don't generate extra chunks when processing a
|
||||||
|
@ -222,7 +217,6 @@ async def test_tool_call_delta_empty_tool_call_buf():
|
||||||
assert chunks[1].event.stop_reason == StopReason.end_of_turn
|
assert chunks[1].event.stop_reason == StopReason.end_of_turn
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_call_delta_streaming_arguments_dict():
|
async def test_tool_call_delta_streaming_arguments_dict():
|
||||||
async def mock_stream():
|
async def mock_stream():
|
||||||
mock_chunk_1 = OpenAIChatCompletionChunk(
|
mock_chunk_1 = OpenAIChatCompletionChunk(
|
||||||
|
@ -297,7 +291,6 @@ async def test_tool_call_delta_streaming_arguments_dict():
|
||||||
assert chunks[2].event.event_type.value == "complete"
|
assert chunks[2].event.event_type.value == "complete"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_multiple_tool_calls():
|
async def test_multiple_tool_calls():
|
||||||
async def mock_stream():
|
async def mock_stream():
|
||||||
mock_chunk_1 = OpenAIChatCompletionChunk(
|
mock_chunk_1 = OpenAIChatCompletionChunk(
|
||||||
|
@ -376,7 +369,6 @@ async def test_multiple_tool_calls():
|
||||||
assert chunks[3].event.event_type.value == "complete"
|
assert chunks[3].event.event_type.value == "complete"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_process_vllm_chat_completion_stream_response_no_choices():
|
async def test_process_vllm_chat_completion_stream_response_no_choices():
|
||||||
"""
|
"""
|
||||||
Test that we don't error out when vLLM returns no choices for a
|
Test that we don't error out when vLLM returns no choices for a
|
||||||
|
@ -453,7 +445,6 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
|
||||||
assert not asyncio_warnings
|
assert not asyncio_warnings
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_params_empty_tools(vllm_inference_adapter):
|
async def test_get_params_empty_tools(vllm_inference_adapter):
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
tools=[],
|
tools=[],
|
||||||
|
@ -464,7 +455,6 @@ async def test_get_params_empty_tools(vllm_inference_adapter):
|
||||||
assert "tools" not in params
|
assert "tools" not in params
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
|
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
|
||||||
"""
|
"""
|
||||||
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
|
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
|
||||||
|
@ -543,7 +533,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
|
||||||
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
|
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
|
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
|
||||||
"""
|
"""
|
||||||
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
|
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
|
||||||
|
@ -596,7 +585,6 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
|
||||||
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
|
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
|
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
|
||||||
"""
|
"""
|
||||||
Tests the edge case where no arguments are provided for the tool call.
|
Tests the edge case where no arguments are provided for the tool call.
|
||||||
|
@ -645,7 +633,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
|
||||||
assert chunks[-2].event.delta.tool_call.arguments == {}
|
assert chunks[-2].event.delta.tool_call.arguments == {}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_health_status_success(vllm_inference_adapter):
|
async def test_health_status_success(vllm_inference_adapter):
|
||||||
"""
|
"""
|
||||||
Test the health method of VLLM InferenceAdapter when the connection is successful.
|
Test the health method of VLLM InferenceAdapter when the connection is successful.
|
||||||
|
@ -679,7 +666,6 @@ async def test_health_status_success(vllm_inference_adapter):
|
||||||
mock_models.list.assert_called_once()
|
mock_models.list.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_health_status_failure(vllm_inference_adapter):
|
async def test_health_status_failure(vllm_inference_adapter):
|
||||||
"""
|
"""
|
||||||
Test the health method of VLLM InferenceAdapter when the connection fails.
|
Test the health method of VLLM InferenceAdapter when the connection fails.
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import TextContentItem
|
from llama_stack.apis.common.content_types import TextContentItem
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
|
@ -23,7 +22,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_convert_message_to_openai_dict():
|
async def test_convert_message_to_openai_dict():
|
||||||
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
|
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
|
||||||
assert await convert_message_to_openai_dict(message) == {
|
assert await convert_message_to_openai_dict(message) == {
|
||||||
|
@ -33,7 +31,6 @@ async def test_convert_message_to_openai_dict():
|
||||||
|
|
||||||
|
|
||||||
# Test convert_message_to_openai_dict with a tool call
|
# Test convert_message_to_openai_dict with a tool call
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_convert_message_to_openai_dict_with_tool_call():
|
async def test_convert_message_to_openai_dict_with_tool_call():
|
||||||
message = CompletionMessage(
|
message = CompletionMessage(
|
||||||
content="",
|
content="",
|
||||||
|
@ -54,7 +51,6 @@ async def test_convert_message_to_openai_dict_with_tool_call():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
|
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
|
||||||
message = CompletionMessage(
|
message = CompletionMessage(
|
||||||
content="",
|
content="",
|
||||||
|
@ -80,7 +76,6 @@ async def test_convert_message_to_openai_dict_with_builtin_tool_call():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_openai_messages_to_messages_with_content_str():
|
async def test_openai_messages_to_messages_with_content_str():
|
||||||
openai_messages = [
|
openai_messages = [
|
||||||
OpenAISystemMessageParam(content="system message"),
|
OpenAISystemMessageParam(content="system message"),
|
||||||
|
@ -98,7 +93,6 @@ async def test_openai_messages_to_messages_with_content_str():
|
||||||
assert llama_messages[2].content == "assistant message"
|
assert llama_messages[2].content == "assistant message"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_openai_messages_to_messages_with_content_list():
|
async def test_openai_messages_to_messages_with_content_list():
|
||||||
openai_messages = [
|
openai_messages = [
|
||||||
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
|
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
|
||||||
|
|
|
@ -13,7 +13,6 @@ from llama_stack.apis.tools import RAGDocument
|
||||||
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
|
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_content_from_doc_with_url():
|
async def test_content_from_doc_with_url():
|
||||||
"""Test extracting content from RAGDocument with URL content."""
|
"""Test extracting content from RAGDocument with URL content."""
|
||||||
mock_url = URL(uri="https://example.com")
|
mock_url = URL(uri="https://example.com")
|
||||||
|
@ -33,7 +32,6 @@ async def test_content_from_doc_with_url():
|
||||||
mock_instance.get.assert_called_once_with(mock_url.uri)
|
mock_instance.get.assert_called_once_with(mock_url.uri)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_content_from_doc_with_pdf_url():
|
async def test_content_from_doc_with_pdf_url():
|
||||||
"""Test extracting content from RAGDocument with URL pointing to a PDF."""
|
"""Test extracting content from RAGDocument with URL pointing to a PDF."""
|
||||||
mock_url = URL(uri="https://example.com/document.pdf")
|
mock_url = URL(uri="https://example.com/document.pdf")
|
||||||
|
@ -58,7 +56,6 @@ async def test_content_from_doc_with_pdf_url():
|
||||||
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
|
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_content_from_doc_with_data_url():
|
async def test_content_from_doc_with_data_url():
|
||||||
"""Test extracting content from RAGDocument with data URL content."""
|
"""Test extracting content from RAGDocument with data URL content."""
|
||||||
data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded
|
data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded
|
||||||
|
@ -74,7 +71,6 @@ async def test_content_from_doc_with_data_url():
|
||||||
mock_content_from_data.assert_called_once_with(data_url)
|
mock_content_from_data.assert_called_once_with(data_url)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_content_from_doc_with_string():
|
async def test_content_from_doc_with_string():
|
||||||
"""Test extracting content from RAGDocument with string content."""
|
"""Test extracting content from RAGDocument with string content."""
|
||||||
content_string = "This is plain text content"
|
content_string = "This is plain text content"
|
||||||
|
@ -85,7 +81,6 @@ async def test_content_from_doc_with_string():
|
||||||
assert result == content_string
|
assert result == content_string
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_content_from_doc_with_string_url():
|
async def test_content_from_doc_with_string_url():
|
||||||
"""Test extracting content from RAGDocument with string URL content."""
|
"""Test extracting content from RAGDocument with string URL content."""
|
||||||
url_string = "https://example.com"
|
url_string = "https://example.com"
|
||||||
|
@ -105,7 +100,6 @@ async def test_content_from_doc_with_string_url():
|
||||||
mock_instance.get.assert_called_once_with(url_string)
|
mock_instance.get.assert_called_once_with(url_string)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_content_from_doc_with_string_pdf_url():
|
async def test_content_from_doc_with_string_pdf_url():
|
||||||
"""Test extracting content from RAGDocument with string URL pointing to a PDF."""
|
"""Test extracting content from RAGDocument with string URL pointing to a PDF."""
|
||||||
url_string = "https://example.com/document.pdf"
|
url_string = "https://example.com/document.pdf"
|
||||||
|
@ -130,7 +124,6 @@ async def test_content_from_doc_with_string_pdf_url():
|
||||||
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
|
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_content_from_doc_with_interleaved_content():
|
async def test_content_from_doc_with_interleaved_content():
|
||||||
"""Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit)."""
|
"""Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit)."""
|
||||||
interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")]
|
interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")]
|
||||||
|
|
|
@ -87,18 +87,15 @@ def helper(known_provider_model: ProviderModelEntry, known_provider_model2: Prov
|
||||||
return ModelRegistryHelper([known_provider_model, known_provider_model2])
|
return ModelRegistryHelper([known_provider_model, known_provider_model2])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||||
assert helper.get_provider_model_id(unknown_model.model_id) is None
|
assert helper.get_provider_model_id(unknown_model.model_id) is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await helper.register_model(unknown_model)
|
await helper.register_model(unknown_model)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
model = Model(
|
model = Model(
|
||||||
provider_id=known_model.provider_id,
|
provider_id=known_model.provider_id,
|
||||||
|
@ -110,7 +107,6 @@ async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -
|
||||||
assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id
|
assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
|
async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
model = Model(
|
model = Model(
|
||||||
provider_id=known_model.provider_id,
|
provider_id=known_model.provider_id,
|
||||||
|
@ -122,13 +118,11 @@ async def test_register_model_from_alias(helper: ModelRegistryHelper, known_mode
|
||||||
assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id
|
assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
|
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
await helper.register_model(known_model)
|
await helper.register_model(known_model)
|
||||||
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
|
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_model_existing_different(
|
async def test_register_model_existing_different(
|
||||||
helper: ModelRegistryHelper, known_model: Model, known_model2: Model
|
helper: ModelRegistryHelper, known_model: Model, known_model2: Model
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -137,7 +131,6 @@ async def test_register_model_existing_different(
|
||||||
await helper.register_model(known_model)
|
await helper.register_model(known_model)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
await helper.register_model(known_model) # duplicate entry
|
await helper.register_model(known_model) # duplicate entry
|
||||||
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
|
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
|
||||||
|
@ -145,18 +138,15 @@ async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model)
|
||||||
assert helper.get_provider_model_id(known_model.model_id) is None
|
assert helper.get_provider_model_id(known_model.model_id) is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await helper.unregister_model(unknown_model.model_id)
|
await helper.unregister_model(unknown_model.model_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||||
await helper.unregister_model(known_model.provider_resource_id)
|
await helper.unregister_model(known_model.provider_resource_id)
|
||||||
|
|
|
@ -11,7 +11,6 @@ import pytest
|
||||||
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
|
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_scheduler_unknown_backend():
|
async def test_scheduler_unknown_backend():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
Scheduler(backend="unknown")
|
Scheduler(backend="unknown")
|
||||||
|
@ -26,7 +25,6 @@ async def wait_for_job_completed(sched: Scheduler, job_id: str) -> None:
|
||||||
raise TimeoutError(f"Job {job_id} did not complete in time.")
|
raise TimeoutError(f"Job {job_id} did not complete in time.")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_scheduler_naive():
|
async def test_scheduler_naive():
|
||||||
sched = Scheduler()
|
sched = Scheduler()
|
||||||
|
|
||||||
|
@ -87,7 +85,6 @@ async def test_scheduler_naive():
|
||||||
assert job.logs[0][0] < job.logs[1][0]
|
assert job.logs[0][0] < job.logs[1][0]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_scheduler_naive_handler_raises():
|
async def test_scheduler_naive_handler_raises():
|
||||||
sched = Scheduler()
|
sched = Scheduler()
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.files import Files
|
from llama_stack.apis.files import Files
|
||||||
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
||||||
|
@ -91,13 +90,13 @@ def faiss_config():
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def faiss_index(embedding_dimension):
|
async def faiss_index(embedding_dimension):
|
||||||
index = await FaissIndex.create(dimension=embedding_dimension)
|
index = await FaissIndex.create(dimension=embedding_dimension)
|
||||||
yield index
|
yield index
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
|
async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
|
||||||
# Create the adapter
|
# Create the adapter
|
||||||
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
|
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
|
||||||
|
@ -113,7 +112,6 @@ async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> Fai
|
||||||
yield adapter
|
yield adapter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
|
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
|
||||||
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
|
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
|
||||||
):
|
):
|
||||||
|
@ -136,7 +134,6 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
|
||||||
assert response.chunks[1] == sample_chunks[1]
|
assert response.chunks[1] == sample_chunks[1]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_health_success():
|
async def test_health_success():
|
||||||
"""Test that the health check returns OK status when faiss is working correctly."""
|
"""Test that the health check returns OK status when faiss is working correctly."""
|
||||||
# Create a fresh instance of FaissVectorIOAdapter for testing
|
# Create a fresh instance of FaissVectorIOAdapter for testing
|
||||||
|
@ -160,7 +157,6 @@ async def test_health_success():
|
||||||
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
|
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_health_failure():
|
async def test_health_failure():
|
||||||
"""Test that the health check returns ERROR status when faiss encounters an error."""
|
"""Test that the health check returns ERROR status when faiss encounters an error."""
|
||||||
# Create a fresh instance of FaissVectorIOAdapter for testing
|
# Create a fresh instance of FaissVectorIOAdapter for testing
|
||||||
|
|
|
@ -10,7 +10,6 @@ from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
||||||
from llama_stack.apis.vector_io import (
|
from llama_stack.apis.vector_io import (
|
||||||
|
@ -68,7 +67,7 @@ def mock_api_service(sample_embeddings):
|
||||||
return mock_api_service
|
return mock_api_service
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
|
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
|
||||||
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
|
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
|
||||||
adapter.vector_db_store = mock_vector_db_store
|
adapter.vector_db_store = mock_vector_db_store
|
||||||
|
@ -80,7 +79,6 @@ async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service,
|
||||||
__QUERY = "Sample query"
|
__QUERY = "Sample query"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
|
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
|
||||||
async def test_qdrant_adapter_returns_expected_chunks(
|
async def test_qdrant_adapter_returns_expected_chunks(
|
||||||
qdrant_adapter: QdrantVectorIOAdapter,
|
qdrant_adapter: QdrantVectorIOAdapter,
|
||||||
|
@ -111,7 +109,6 @@ def _prepare_for_json(value: Any) -> str:
|
||||||
|
|
||||||
|
|
||||||
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
|
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_qdrant_register_and_unregister_vector_db(
|
async def test_qdrant_register_and_unregister_vector_db(
|
||||||
qdrant_adapter: QdrantVectorIOAdapter,
|
qdrant_adapter: QdrantVectorIOAdapter,
|
||||||
mock_vector_db,
|
mock_vector_db,
|
||||||
|
|
|
@ -8,7 +8,6 @@ import asyncio
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
|
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
|
||||||
|
@ -34,7 +33,7 @@ def loop():
|
||||||
return asyncio.new_event_loop()
|
return asyncio.new_event_loop()
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
|
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
|
||||||
temp_dir = tmp_path_factory.getbasetemp()
|
temp_dir = tmp_path_factory.getbasetemp()
|
||||||
db_path = str(temp_dir / "test_sqlite.db")
|
db_path = str(temp_dir / "test_sqlite.db")
|
||||||
|
@ -43,14 +42,12 @@ async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
|
||||||
await index.delete()
|
await index.delete()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata):
|
async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata):
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)
|
await sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)
|
||||||
response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0)
|
response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0)
|
||||||
assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata
|
assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
query_string = "Sentence 5"
|
query_string = "Sentence 5"
|
||||||
|
@ -68,7 +65,6 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
|
||||||
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
|
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
@ -90,7 +86,6 @@ async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embed
|
||||||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
# Re-initialize with a clean index
|
# Re-initialize with a clean index
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
@ -103,7 +98,6 @@ async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_i
|
||||||
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"
|
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
|
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
|
||||||
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
|
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
|
||||||
# Reduce batch size to force multiple batches for same document
|
# Reduce batch size to force multiple batches for same document
|
||||||
|
@ -134,7 +128,6 @@ async def sqlite_vec_adapter(sqlite_connection):
|
||||||
await adapter.shutdown()
|
await adapter.shutdown()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
"""Test hybrid search when keyword search returns no matches - should still return vector results."""
|
"""Test hybrid search when keyword search returns no matches - should still return vector results."""
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
@ -163,7 +156,6 @@ async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_c
|
||||||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
"""Test hybrid search with a high score threshold."""
|
"""Test hybrid search with a high score threshold."""
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
@ -185,7 +177,6 @@ async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chun
|
||||||
assert len(response.chunks) == 0
|
assert len(response.chunks) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_hybrid_different_embedding(
|
async def test_query_chunks_hybrid_different_embedding(
|
||||||
sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
|
sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
|
||||||
):
|
):
|
||||||
|
@ -211,7 +202,6 @@ async def test_query_chunks_hybrid_different_embedding(
|
||||||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
"""Test that RRF properly combines rankings when documents appear in both search methods."""
|
"""Test that RRF properly combines rankings when documents appear in both search methods."""
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
@ -236,7 +226,6 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
|
||||||
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
@ -284,7 +273,6 @@ async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chun
|
||||||
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
|
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
"""Test hybrid search with documents that appear in only one search method."""
|
"""Test hybrid search with documents that appear in only one search method."""
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
@ -313,7 +301,6 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
|
||||||
assert "document-2" in doc_ids # From keyword search
|
assert "document-2" in doc_ids # From keyword search
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_hybrid_weighted_reranker_parametrization(
|
async def test_query_chunks_hybrid_weighted_reranker_parametrization(
|
||||||
sqlite_vec_index, sample_chunks, sample_embeddings
|
sqlite_vec_index, sample_chunks, sample_embeddings
|
||||||
):
|
):
|
||||||
|
@ -369,7 +356,6 @@ async def test_query_chunks_hybrid_weighted_reranker_parametrization(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
"""Test RRFReRanker with different impact factors."""
|
"""Test RRFReRanker with different impact factors."""
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
@ -401,7 +387,6 @@ async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_ch
|
||||||
assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
|
assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
@ -445,7 +430,6 @@ async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, s
|
||||||
assert len(response.chunks) <= 100
|
assert len(response.chunks) <= 100
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_hybrid_tie_breaking(
|
async def test_query_chunks_hybrid_tie_breaking(
|
||||||
sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
|
sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
|
||||||
):
|
):
|
||||||
|
|
|
@ -25,12 +25,10 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREF
|
||||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize_index(vector_index):
|
async def test_initialize_index(vector_index):
|
||||||
await vector_index.initialize()
|
await vector_index.initialize()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
|
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
|
||||||
vector_index.delete()
|
vector_index.delete()
|
||||||
vector_index.initialize()
|
vector_index.initialize()
|
||||||
|
@ -40,7 +38,6 @@ async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embed
|
||||||
vector_index.delete()
|
vector_index.delete()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
|
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
|
||||||
embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
|
embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
|
||||||
await vector_index.add_chunks(sample_chunks, embeddings)
|
await vector_index.add_chunks(sample_chunks, embeddings)
|
||||||
|
@ -54,7 +51,6 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
|
||||||
assert len(contents) == len(set(contents))
|
assert len(contents) == len(set(contents))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
||||||
key = f"{VECTOR_DBS_PREFIX}db1"
|
key = f"{VECTOR_DBS_PREFIX}db1"
|
||||||
dummy = VectorDB(
|
dummy = VectorDB(
|
||||||
|
@ -65,7 +61,6 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
||||||
await vector_io_adapter.initialize()
|
await vector_io_adapter.initialize()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
||||||
await vector_io_adapter.initialize()
|
await vector_io_adapter.initialize()
|
||||||
dummy = VectorDB(
|
dummy = VectorDB(
|
||||||
|
@ -79,7 +74,6 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
||||||
await vector_io_adapter.shutdown()
|
await vector_io_adapter.shutdown()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_register_and_unregister_vector_db(vector_io_adapter):
|
async def test_register_and_unregister_vector_db(vector_io_adapter):
|
||||||
unique_id = f"foo_db_{np.random.randint(1e6)}"
|
unique_id = f"foo_db_{np.random.randint(1e6)}"
|
||||||
dummy = VectorDB(
|
dummy = VectorDB(
|
||||||
|
@ -92,14 +86,12 @@ async def test_register_and_unregister_vector_db(vector_io_adapter):
|
||||||
assert dummy.identifier not in vector_io_adapter.cache
|
assert dummy.identifier not in vector_io_adapter.cache
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_unregistered_raises(vector_io_adapter):
|
async def test_query_unregistered_raises(vector_io_adapter):
|
||||||
fake_emb = np.zeros(8, dtype=np.float32)
|
fake_emb = np.zeros(8, dtype=np.float32)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
||||||
fake_index = AsyncMock()
|
fake_index = AsyncMock()
|
||||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
|
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
|
||||||
|
@ -110,7 +102,6 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
||||||
fake_index.insert_chunks.assert_awaited_once_with(chunks)
|
fake_index.insert_chunks.assert_awaited_once_with(chunks)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
|
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
|
||||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
@ -118,7 +109,6 @@ async def test_insert_chunks_missing_db_raises(vector_io_adapter):
|
||||||
await vector_io_adapter.insert_chunks("db_not_exist", [])
|
await vector_io_adapter.insert_chunks("db_not_exist", [])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
|
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
|
||||||
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
|
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
|
||||||
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
|
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
|
||||||
|
@ -130,7 +120,6 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
|
||||||
assert response is expected
|
assert response is expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunks_missing_db_raises(vector_io_adapter):
|
async def test_query_chunks_missing_db_raises(vector_io_adapter):
|
||||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
@ -138,7 +127,6 @@ async def test_query_chunks_missing_db_raises(vector_io_adapter):
|
||||||
await vector_io_adapter.query_chunks("db_missing", "q", None)
|
await vector_io_adapter.query_chunks("db_missing", "q", None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_save_openai_vector_store(vector_io_adapter):
|
async def test_save_openai_vector_store(vector_io_adapter):
|
||||||
store_id = "vs_1234"
|
store_id = "vs_1234"
|
||||||
openai_vector_store = {
|
openai_vector_store = {
|
||||||
|
@ -155,7 +143,6 @@ async def test_save_openai_vector_store(vector_io_adapter):
|
||||||
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
|
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_openai_vector_store(vector_io_adapter):
|
async def test_update_openai_vector_store(vector_io_adapter):
|
||||||
store_id = "vs_1234"
|
store_id = "vs_1234"
|
||||||
openai_vector_store = {
|
openai_vector_store = {
|
||||||
|
@ -172,7 +159,6 @@ async def test_update_openai_vector_store(vector_io_adapter):
|
||||||
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
|
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_openai_vector_store(vector_io_adapter):
|
async def test_delete_openai_vector_store(vector_io_adapter):
|
||||||
store_id = "vs_1234"
|
store_id = "vs_1234"
|
||||||
openai_vector_store = {
|
openai_vector_store = {
|
||||||
|
@ -188,7 +174,6 @@ async def test_delete_openai_vector_store(vector_io_adapter):
|
||||||
assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores
|
assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_load_openai_vector_stores(vector_io_adapter):
|
async def test_load_openai_vector_stores(vector_io_adapter):
|
||||||
store_id = "vs_1234"
|
store_id = "vs_1234"
|
||||||
openai_vector_store = {
|
openai_vector_store = {
|
||||||
|
@ -204,7 +189,6 @@ async def test_load_openai_vector_stores(vector_io_adapter):
|
||||||
assert loaded_stores[store_id] == openai_vector_store
|
assert loaded_stores[store_id] == openai_vector_store
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
|
async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
|
||||||
store_id = "vs_1234"
|
store_id = "vs_1234"
|
||||||
file_id = "file_1234"
|
file_id = "file_1234"
|
||||||
|
@ -226,7 +210,6 @@ async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory
|
||||||
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
|
async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
|
||||||
store_id = "vs_1234"
|
store_id = "vs_1234"
|
||||||
file_id = "file_1234"
|
file_id = "file_1234"
|
||||||
|
@ -260,7 +243,6 @@ async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_facto
|
||||||
assert loaded_contents != file_info
|
assert loaded_contents != file_info
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory):
|
async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory):
|
||||||
store_id = "vs_1234"
|
store_id = "vs_1234"
|
||||||
file_id = "file_1234"
|
file_id = "file_1234"
|
||||||
|
@ -284,7 +266,6 @@ async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_pat
|
||||||
assert loaded_contents == file_contents
|
assert loaded_contents == file_contents
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory):
|
async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory):
|
||||||
store_id = "vs_1234"
|
store_id = "vs_1234"
|
||||||
file_id = "file_1234"
|
file_id = "file_1234"
|
||||||
|
|
|
@ -17,13 +17,11 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
|
||||||
|
|
||||||
|
|
||||||
class TestRagQuery:
|
class TestRagQuery:
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_raises_on_empty_vector_db_ids(self):
|
async def test_query_raises_on_empty_vector_db_ids(self):
|
||||||
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
|
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
|
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_chunk_metadata_handling(self):
|
async def test_query_chunk_metadata_handling(self):
|
||||||
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
|
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
|
||||||
content = "test query content"
|
content = "test query content"
|
||||||
|
|
|
@ -112,7 +112,6 @@ class TestValidateEmbedding:
|
||||||
|
|
||||||
|
|
||||||
class TestVectorStore:
|
class TestVectorStore:
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_returns_content_from_pdf_data_uri(self):
|
async def test_returns_content_from_pdf_data_uri(self):
|
||||||
data_uri = data_url_from_file(DUMMY_PDF_PATH)
|
data_uri = data_url_from_file(DUMMY_PDF_PATH)
|
||||||
doc = RAGDocument(
|
doc = RAGDocument(
|
||||||
|
@ -124,7 +123,6 @@ class TestVectorStore:
|
||||||
content = await content_from_doc(doc)
|
content = await content_from_doc(doc)
|
||||||
assert content in DUMMY_PDF_TEXT_CHOICES
|
assert content in DUMMY_PDF_TEXT_CHOICES
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_downloads_pdf_and_returns_content(self):
|
async def test_downloads_pdf_and_returns_content(self):
|
||||||
# Using GitHub to host the PDF file
|
# Using GitHub to host the PDF file
|
||||||
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
|
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
|
||||||
|
@ -137,7 +135,6 @@ class TestVectorStore:
|
||||||
content = await content_from_doc(doc)
|
content = await content_from_doc(doc)
|
||||||
assert content in DUMMY_PDF_TEXT_CHOICES
|
assert content in DUMMY_PDF_TEXT_CHOICES
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_downloads_pdf_and_returns_content_with_url_object(self):
|
async def test_downloads_pdf_and_returns_content_with_url_object(self):
|
||||||
# Using GitHub to host the PDF file
|
# Using GitHub to host the PDF file
|
||||||
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
|
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
|
||||||
|
@ -204,7 +201,6 @@ class TestVectorStore:
|
||||||
|
|
||||||
|
|
||||||
class TestVectorDBWithIndex:
|
class TestVectorDBWithIndex:
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_insert_chunks_without_embeddings(self):
|
async def test_insert_chunks_without_embeddings(self):
|
||||||
mock_vector_db = MagicMock()
|
mock_vector_db = MagicMock()
|
||||||
mock_vector_db.embedding_model = "test-model without embeddings"
|
mock_vector_db.embedding_model = "test-model without embeddings"
|
||||||
|
@ -230,7 +226,6 @@ class TestVectorDBWithIndex:
|
||||||
assert args[0] == chunks
|
assert args[0] == chunks
|
||||||
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_insert_chunks_with_valid_embeddings(self):
|
async def test_insert_chunks_with_valid_embeddings(self):
|
||||||
mock_vector_db = MagicMock()
|
mock_vector_db = MagicMock()
|
||||||
mock_vector_db.embedding_model = "test-model with embeddings"
|
mock_vector_db.embedding_model = "test-model with embeddings"
|
||||||
|
@ -255,7 +250,6 @@ class TestVectorDBWithIndex:
|
||||||
assert args[0] == chunks
|
assert args[0] == chunks
|
||||||
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_insert_chunks_with_invalid_embeddings(self):
|
async def test_insert_chunks_with_invalid_embeddings(self):
|
||||||
mock_vector_db = MagicMock()
|
mock_vector_db = MagicMock()
|
||||||
mock_vector_db.embedding_dimension = 3
|
mock_vector_db.embedding_dimension = 3
|
||||||
|
@ -295,7 +289,6 @@ class TestVectorDBWithIndex:
|
||||||
mock_inference_api.embeddings.assert_not_called()
|
mock_inference_api.embeddings.assert_not_called()
|
||||||
mock_index.add_chunks.assert_not_called()
|
mock_index.add_chunks.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
|
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
|
||||||
mock_vector_db = MagicMock()
|
mock_vector_db = MagicMock()
|
||||||
mock_vector_db.embedding_model = "test-model with partial embeddings"
|
mock_vector_db.embedding_model = "test-model with partial embeddings"
|
||||||
|
|
|
@ -38,14 +38,12 @@ def sample_model():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_registry_initialization(disk_dist_registry):
|
async def test_registry_initialization(disk_dist_registry):
|
||||||
# Test empty registry
|
# Test empty registry
|
||||||
result = await disk_dist_registry.get("nonexistent", "nonexistent")
|
result = await disk_dist_registry.get("nonexistent", "nonexistent")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
|
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
|
||||||
print(f"Registering {sample_vector_db}")
|
print(f"Registering {sample_vector_db}")
|
||||||
await disk_dist_registry.register(sample_vector_db)
|
await disk_dist_registry.register(sample_vector_db)
|
||||||
|
@ -64,7 +62,6 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m
|
||||||
assert result_model.provider_id == sample_model.provider_id
|
assert result_model.provider_id == sample_model.provider_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
|
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
|
||||||
# First populate the disk registry
|
# First populate the disk registry
|
||||||
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
|
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
|
||||||
|
@ -85,7 +82,6 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db,
|
||||||
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
assert result_vector_db.provider_id == sample_vector_db.provider_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_cached_registry_updates(cached_disk_dist_registry):
|
async def test_cached_registry_updates(cached_disk_dist_registry):
|
||||||
new_vector_db = VectorDB(
|
new_vector_db = VectorDB(
|
||||||
identifier="test_vector_db_2",
|
identifier="test_vector_db_2",
|
||||||
|
@ -112,7 +108,6 @@ async def test_cached_registry_updates(cached_disk_dist_registry):
|
||||||
assert result_vector_db.provider_id == new_vector_db.provider_id
|
assert result_vector_db.provider_id == new_vector_db.provider_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
||||||
original_vector_db = VectorDB(
|
original_vector_db = VectorDB(
|
||||||
identifier="test_vector_db_2",
|
identifier="test_vector_db_2",
|
||||||
|
@ -137,7 +132,6 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
||||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_all_objects(cached_disk_dist_registry):
|
async def test_get_all_objects(cached_disk_dist_registry):
|
||||||
# Create multiple test banks
|
# Create multiple test banks
|
||||||
# Create multiple test banks
|
# Create multiple test banks
|
||||||
|
@ -170,7 +164,6 @@ async def test_get_all_objects(cached_disk_dist_registry):
|
||||||
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
||||||
valid_db = VectorDB(
|
valid_db = VectorDB(
|
||||||
identifier="valid_vector_db",
|
identifier="valid_vector_db",
|
||||||
|
@ -209,7 +202,6 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
||||||
assert invalid_obj is None
|
assert invalid_obj is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_cached_registry_error_handling(sqlite_kvstore):
|
async def test_cached_registry_error_handling(sqlite_kvstore):
|
||||||
valid_db = VectorDB(
|
valid_db = VectorDB(
|
||||||
identifier="valid_cached_db",
|
identifier="valid_cached_db",
|
||||||
|
|
|
@ -5,14 +5,11 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.distribution.datatypes import ModelWithOwner, User
|
from llama_stack.distribution.datatypes import ModelWithOwner, User
|
||||||
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
||||||
model = ModelWithOwner(
|
model = ModelWithOwner(
|
||||||
identifier="model-acl",
|
identifier="model-acl",
|
||||||
|
@ -48,7 +45,6 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
|
||||||
assert new_model.owner.attributes["teams"] == ["ai-team"]
|
assert new_model.owner.attributes["teams"] == ["ai-team"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_registry_empty_acl(cached_disk_dist_registry):
|
async def test_registry_empty_acl(cached_disk_dist_registry):
|
||||||
model = ModelWithOwner(
|
model = ModelWithOwner(
|
||||||
identifier="model-empty-acl",
|
identifier="model-empty-acl",
|
||||||
|
@ -85,7 +81,6 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
|
||||||
assert len(all_models) == 2
|
assert len(all_models) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_registry_serialization(cached_disk_dist_registry):
|
async def test_registry_serialization(cached_disk_dist_registry):
|
||||||
attributes = {
|
attributes = {
|
||||||
"roles": ["admin", "researcher"],
|
"roles": ["admin", "researcher"],
|
||||||
|
|
|
@ -7,7 +7,6 @@
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
import yaml
|
import yaml
|
||||||
from pydantic import TypeAdapter, ValidationError
|
from pydantic import TypeAdapter, ValidationError
|
||||||
|
|
||||||
|
@ -27,7 +26,7 @@ def _return_model(model):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def test_setup(cached_disk_dist_registry):
|
async def test_setup(cached_disk_dist_registry):
|
||||||
mock_inference = Mock()
|
mock_inference = Mock()
|
||||||
mock_inference.__provider_spec__ = MagicMock()
|
mock_inference.__provider_spec__ = MagicMock()
|
||||||
|
@ -41,7 +40,6 @@ async def test_setup(cached_disk_dist_registry):
|
||||||
yield cached_disk_dist_registry, routing_table
|
yield cached_disk_dist_registry, routing_table
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
|
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
|
@ -106,7 +104,6 @@ async def test_access_control_with_cache(mock_get_authenticated_user, test_setup
|
||||||
await routing_table.get_model("model-admin")
|
await routing_table.get_model("model-admin")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
|
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
|
@ -145,7 +142,6 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
|
||||||
assert model.identifier == "model-updates"
|
assert model.identifier == "model-updates"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
|
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
|
@ -170,7 +166,6 @@ async def test_access_control_empty_attributes(mock_get_authenticated_user, test
|
||||||
assert "model-empty-attrs" in model_ids
|
assert "model-empty-attrs" in model_ids
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
|
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
|
||||||
registry, routing_table = test_setup
|
registry, routing_table = test_setup
|
||||||
|
@ -201,7 +196,6 @@ async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
|
||||||
assert all_models.data[0].identifier == "model-public-2"
|
assert all_models.data[0].identifier == "model-public-2"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
|
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
|
||||||
"""Test that newly created resources inherit access attributes from their creator."""
|
"""Test that newly created resources inherit access attributes from their creator."""
|
||||||
|
@ -246,7 +240,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
|
||||||
assert model.identifier == "auto-access-model"
|
assert model.identifier == "auto-access-model"
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest.fixture
|
||||||
async def test_setup_with_access_policy(cached_disk_dist_registry):
|
async def test_setup_with_access_policy(cached_disk_dist_registry):
|
||||||
mock_inference = Mock()
|
mock_inference = Mock()
|
||||||
mock_inference.__provider_spec__ = MagicMock()
|
mock_inference.__provider_spec__ = MagicMock()
|
||||||
|
@ -281,7 +275,6 @@ async def test_setup_with_access_policy(cached_disk_dist_registry):
|
||||||
yield routing_table
|
yield routing_table
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
|
||||||
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
|
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
|
||||||
routing_table = test_setup_with_access_policy
|
routing_table = test_setup_with_access_policy
|
||||||
|
|
|
@ -202,7 +202,6 @@ def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoin
|
||||||
assert "param2" in payload["request"]["params"]
|
assert "param2" in payload["request"]["params"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_http_middleware_with_access_attributes(mock_http_middleware, mock_scope):
|
async def test_http_middleware_with_access_attributes(mock_http_middleware, mock_scope):
|
||||||
"""Test HTTP middleware behavior with access attributes"""
|
"""Test HTTP middleware behavior with access attributes"""
|
||||||
middleware, mock_app = mock_http_middleware
|
middleware, mock_app = mock_http_middleware
|
||||||
|
|
|
@ -9,7 +9,6 @@ import sys
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.apis.inference import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
|
@ -66,7 +65,6 @@ class SampleImpl:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_resolve_impls_basic():
|
async def test_resolve_impls_basic():
|
||||||
# Create a real provider spec
|
# Create a real provider spec
|
||||||
provider_spec = InlineProviderSpec(
|
provider_spec = InlineProviderSpec(
|
||||||
|
|
|
@ -7,13 +7,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.distribution.server.server import create_dynamic_typed_route, create_sse_event, sse_generator
|
from llama_stack.distribution.server.server import create_dynamic_typed_route, create_sse_event, sse_generator
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sse_generator_basic():
|
async def test_sse_generator_basic():
|
||||||
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
||||||
async def async_event_gen():
|
async def async_event_gen():
|
||||||
|
@ -35,7 +32,6 @@ async def test_sse_generator_basic():
|
||||||
assert seen_events[1] == create_sse_event("Test event 2")
|
assert seen_events[1] == create_sse_event("Test event 2")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sse_generator_client_disconnected():
|
async def test_sse_generator_client_disconnected():
|
||||||
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
||||||
async def async_event_gen():
|
async def async_event_gen():
|
||||||
|
@ -58,7 +54,6 @@ async def test_sse_generator_client_disconnected():
|
||||||
assert seen_events[0] == create_sse_event("Test event 1")
|
assert seen_events[0] == create_sse_event("Test event 1")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sse_generator_client_disconnected_before_response_starts():
|
async def test_sse_generator_client_disconnected_before_response_starts():
|
||||||
# Disconnect before the response starts
|
# Disconnect before the response starts
|
||||||
async def async_event_gen():
|
async def async_event_gen():
|
||||||
|
@ -75,7 +70,6 @@ async def test_sse_generator_client_disconnected_before_response_starts():
|
||||||
assert len(seen_events) == 0
|
assert len(seen_events) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sse_generator_error_before_response_starts():
|
async def test_sse_generator_error_before_response_starts():
|
||||||
# Raise an error before the response starts
|
# Raise an error before the response starts
|
||||||
async def async_event_gen():
|
async def async_event_gen():
|
||||||
|
@ -93,7 +87,6 @@ async def test_sse_generator_error_before_response_starts():
|
||||||
assert 'data: {"error":' in seen_events[0]
|
assert 'data: {"error":' in seen_events[0]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_paginated_response_url_setting():
|
async def test_paginated_response_url_setting():
|
||||||
"""Test that PaginatedResponse gets url set to route path."""
|
"""Test that PaginatedResponse gets url set to route path."""
|
||||||
|
|
||||||
|
|
|
@ -42,7 +42,6 @@ def create_test_chat_completion(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inference_store_pagination_basic():
|
async def test_inference_store_pagination_basic():
|
||||||
"""Test basic pagination functionality."""
|
"""Test basic pagination functionality."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -88,7 +87,6 @@ async def test_inference_store_pagination_basic():
|
||||||
assert result3.has_more is False
|
assert result3.has_more is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inference_store_pagination_ascending():
|
async def test_inference_store_pagination_ascending():
|
||||||
"""Test pagination with ascending order."""
|
"""Test pagination with ascending order."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -123,7 +121,6 @@ async def test_inference_store_pagination_ascending():
|
||||||
assert result2.has_more is True
|
assert result2.has_more is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inference_store_pagination_with_model_filter():
|
async def test_inference_store_pagination_with_model_filter():
|
||||||
"""Test pagination combined with model filtering."""
|
"""Test pagination combined with model filtering."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -161,7 +158,6 @@ async def test_inference_store_pagination_with_model_filter():
|
||||||
assert result2.has_more is False
|
assert result2.has_more is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inference_store_pagination_invalid_after():
|
async def test_inference_store_pagination_invalid_after():
|
||||||
"""Test error handling for invalid 'after' parameter."""
|
"""Test error handling for invalid 'after' parameter."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -174,7 +170,6 @@ async def test_inference_store_pagination_invalid_after():
|
||||||
await store.list_chat_completions(after="non-existent", limit=2)
|
await store.list_chat_completions(after="non-existent", limit=2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inference_store_pagination_no_limit():
|
async def test_inference_store_pagination_no_limit():
|
||||||
"""Test pagination behavior when no limit is specified."""
|
"""Test pagination behavior when no limit is specified."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
|
|
@ -44,7 +44,6 @@ def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_responses_store_pagination_basic():
|
async def test_responses_store_pagination_basic():
|
||||||
"""Test basic pagination functionality for responses store."""
|
"""Test basic pagination functionality for responses store."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -90,7 +89,6 @@ async def test_responses_store_pagination_basic():
|
||||||
assert result3.has_more is False
|
assert result3.has_more is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_responses_store_pagination_ascending():
|
async def test_responses_store_pagination_ascending():
|
||||||
"""Test pagination with ascending order."""
|
"""Test pagination with ascending order."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -125,7 +123,6 @@ async def test_responses_store_pagination_ascending():
|
||||||
assert result2.has_more is True
|
assert result2.has_more is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_responses_store_pagination_with_model_filter():
|
async def test_responses_store_pagination_with_model_filter():
|
||||||
"""Test pagination combined with model filtering."""
|
"""Test pagination combined with model filtering."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -163,7 +160,6 @@ async def test_responses_store_pagination_with_model_filter():
|
||||||
assert result2.has_more is False
|
assert result2.has_more is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_responses_store_pagination_invalid_after():
|
async def test_responses_store_pagination_invalid_after():
|
||||||
"""Test error handling for invalid 'after' parameter."""
|
"""Test error handling for invalid 'after' parameter."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -176,7 +172,6 @@ async def test_responses_store_pagination_invalid_after():
|
||||||
await store.list_responses(after="non-existent", limit=2)
|
await store.list_responses(after="non-existent", limit=2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_responses_store_pagination_no_limit():
|
async def test_responses_store_pagination_no_limit():
|
||||||
"""Test pagination behavior when no limit is specified."""
|
"""Test pagination behavior when no limit is specified."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -205,7 +200,6 @@ async def test_responses_store_pagination_no_limit():
|
||||||
assert result.has_more is False
|
assert result.has_more is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_responses_store_get_response_object():
|
async def test_responses_store_get_response_object():
|
||||||
"""Test retrieving a single response object."""
|
"""Test retrieving a single response object."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -230,7 +224,6 @@ async def test_responses_store_get_response_object():
|
||||||
await store.get_response_object("non-existent")
|
await store.get_response_object("non-existent")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_responses_store_input_items_pagination():
|
async def test_responses_store_input_items_pagination():
|
||||||
"""Test pagination functionality for input items."""
|
"""Test pagination functionality for input items."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -308,7 +301,6 @@ async def test_responses_store_input_items_pagination():
|
||||||
await store.list_response_input_items("test-resp", before="some-id", after="other-id")
|
await store.list_response_input_items("test-resp", before="some-id", after="other-id")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_responses_store_input_items_before_pagination():
|
async def test_responses_store_input_items_before_pagination():
|
||||||
"""Test before pagination functionality for input items."""
|
"""Test before pagination functionality for input items."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
|
|
@ -14,7 +14,6 @@ from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemyS
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sqlite_sqlstore():
|
async def test_sqlite_sqlstore():
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_name = "test.db"
|
db_name = "test.db"
|
||||||
|
@ -66,7 +65,6 @@ async def test_sqlite_sqlstore():
|
||||||
assert result.has_more is False
|
assert result.has_more is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sqlstore_pagination_basic():
|
async def test_sqlstore_pagination_basic():
|
||||||
"""Test basic pagination functionality at the SQL store level."""
|
"""Test basic pagination functionality at the SQL store level."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -131,7 +129,6 @@ async def test_sqlstore_pagination_basic():
|
||||||
assert result3.has_more is False
|
assert result3.has_more is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sqlstore_pagination_with_filter():
|
async def test_sqlstore_pagination_with_filter():
|
||||||
"""Test pagination with WHERE conditions."""
|
"""Test pagination with WHERE conditions."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -184,7 +181,6 @@ async def test_sqlstore_pagination_with_filter():
|
||||||
assert result2.has_more is False
|
assert result2.has_more is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sqlstore_pagination_ascending_order():
|
async def test_sqlstore_pagination_ascending_order():
|
||||||
"""Test pagination with ascending order."""
|
"""Test pagination with ascending order."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -233,7 +229,6 @@ async def test_sqlstore_pagination_ascending_order():
|
||||||
assert result2.has_more is True
|
assert result2.has_more is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sqlstore_pagination_multi_column_ordering_error():
|
async def test_sqlstore_pagination_multi_column_ordering_error():
|
||||||
"""Test that multi-column ordering raises an error when using cursor pagination."""
|
"""Test that multi-column ordering raises an error when using cursor pagination."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -271,7 +266,6 @@ async def test_sqlstore_pagination_multi_column_ordering_error():
|
||||||
assert result.data[0]["id"] == "task1"
|
assert result.data[0]["id"] == "task1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sqlstore_pagination_cursor_requires_order_by():
|
async def test_sqlstore_pagination_cursor_requires_order_by():
|
||||||
"""Test that cursor pagination requires order_by parameter."""
|
"""Test that cursor pagination requires order_by parameter."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -289,7 +283,6 @@ async def test_sqlstore_pagination_cursor_requires_order_by():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sqlstore_pagination_error_handling():
|
async def test_sqlstore_pagination_error_handling():
|
||||||
"""Test error handling for invalid columns and cursor IDs."""
|
"""Test error handling for invalid columns and cursor IDs."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -339,7 +332,6 @@ async def test_sqlstore_pagination_error_handling():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sqlstore_pagination_custom_key_column():
|
async def test_sqlstore_pagination_custom_key_column():
|
||||||
"""Test pagination with custom primary key column (not 'id')."""
|
"""Test pagination with custom primary key column (not 'id')."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
|
|
@ -7,8 +7,6 @@
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
|
from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
|
||||||
from llama_stack.distribution.access_control.datatypes import Action
|
from llama_stack.distribution.access_control.datatypes import Action
|
||||||
from llama_stack.distribution.datatypes import User
|
from llama_stack.distribution.datatypes import User
|
||||||
|
@ -18,7 +16,6 @@ from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemyS
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
|
async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
|
||||||
"""Test that fetch_all works correctly with where_sql for access control"""
|
"""Test that fetch_all works correctly with where_sql for access control"""
|
||||||
|
@ -81,7 +78,6 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
|
||||||
assert row["title"] == "User Document"
|
assert row["title"] == "User Document"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
async def test_sql_policy_consistency(mock_get_authenticated_user):
|
async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||||
"""Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic"""
|
"""Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic"""
|
||||||
|
@ -168,7 +164,6 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
|
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
|
||||||
"""Test that user attributes are properly captured during insert"""
|
"""Test that user attributes are properly captured during insert"""
|
||||||
|
|
17
uv.lock
generated
17
uv.lock
generated
|
@ -1394,8 +1394,8 @@ dev = [
|
||||||
{ name = "black" },
|
{ name = "black" },
|
||||||
{ name = "nbval" },
|
{ name = "nbval" },
|
||||||
{ name = "pre-commit" },
|
{ name = "pre-commit" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest", specifier = ">=8.4" },
|
||||||
{ name = "pytest-asyncio" },
|
{ name = "pytest-asyncio", specifier = ">=1.0" },
|
||||||
{ name = "pytest-cov" },
|
{ name = "pytest-cov" },
|
||||||
{ name = "pytest-html" },
|
{ name = "pytest-html" },
|
||||||
{ name = "pytest-json-report" },
|
{ name = "pytest-json-report" },
|
||||||
|
@ -2432,29 +2432,30 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest"
|
name = "pytest"
|
||||||
version = "8.3.4"
|
version = "8.4.1"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||||
{ name = "iniconfig" },
|
{ name = "iniconfig" },
|
||||||
{ name = "packaging" },
|
{ name = "packaging" },
|
||||||
{ name = "pluggy" },
|
{ name = "pluggy" },
|
||||||
|
{ name = "pygments" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919, upload-time = "2024-12-01T12:54:25.98Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083, upload-time = "2024-12-01T12:54:19.735Z" },
|
{ url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest-asyncio"
|
name = "pytest-asyncio"
|
||||||
version = "0.25.3"
|
version = "1.0.0"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/f2/a8/ecbc8ede70921dd2f544ab1cadd3ff3bf842af27f87bbdea774c7baa1d38/pytest_asyncio-0.25.3.tar.gz", hash = "sha256:fc1da2cf9f125ada7e710b4ddad05518d4cee187ae9412e9ac9271003497f07a", size = 54239, upload-time = "2025-01-28T18:37:58.729Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/d0/d4/14f53324cb1a6381bef29d698987625d80052bb33932d8e7cbf9b337b17c/pytest_asyncio-1.0.0.tar.gz", hash = "sha256:d15463d13f4456e1ead2594520216b225a16f781e144f8fdf6c5bb4667c48b3f", size = 46960, upload-time = "2025-05-26T04:54:40.484Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/67/17/3493c5624e48fd97156ebaec380dcaafee9506d7e2c46218ceebbb57d7de/pytest_asyncio-0.25.3-py3-none-any.whl", hash = "sha256:9e89518e0f9bd08928f97a3482fdc4e244df17529460bc038291ccaf8f85c7c3", size = 19467, upload-time = "2025-01-28T18:37:56.798Z" },
|
{ url = "https://files.pythonhosted.org/packages/30/05/ce271016e351fddc8399e546f6e23761967ee09c8c568bbfbecb0c150171/pytest_asyncio-1.0.0-py3-none-any.whl", hash = "sha256:4f024da9f1ef945e680dc68610b52550e36590a67fd31bb3b4943979a1f90ef3", size = 15976, upload-time = "2025-05-26T04:54:39.035Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue