Merge branch 'main' into allow-dynamic-models-nvidia

This commit is contained in:
Matthew Farrellee 2025-07-14 19:01:28 -04:00
commit c2ab8988e6
127 changed files with 3997 additions and 3394 deletions

View file

@ -4,6 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest_socket
# We need to import the fixtures here so that pytest can find them
# but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed
from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore # noqa: F401
def pytest_runtest_setup(item):
"""Setup for each test - check if network access should be allowed."""
if "allow_network" in item.keywords:
pytest_socket.enable_socket()
else:
# Allowing Unix sockets is necessary for some tests that use local servers and mocks
pytest_socket.disable_socket(allow_unix_socket=True)

View file

@ -8,8 +8,6 @@
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
@ -119,7 +117,6 @@ class ToolGroupsImpl(Impl):
)
@pytest.mark.asyncio
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -161,7 +158,6 @@ async def test_models_routing_table(cached_disk_dist_registry):
assert len(openai_models.data) == 0
@pytest.mark.asyncio
async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -177,7 +173,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
assert "test-shield-2" in shield_ids
@pytest.mark.asyncio
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -233,7 +228,6 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
assert len(datasets.data) == 0
@pytest.mark.asyncio
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -259,7 +253,6 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
assert "test-scoring-fn-2" in scoring_fn_ids
@pytest.mark.asyncio
async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -277,7 +270,6 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
assert "test-benchmark" in benchmark_ids
@pytest.mark.asyncio
async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
await table.initialize()

View file

@ -13,7 +13,6 @@ import pytest
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
# Create context variable
context_var = ContextVar("exception_var", default="initial")
@ -41,7 +40,6 @@ async def test_preserve_contexts_with_exception():
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
# Create context variable
context_var = ContextVar("empty_var", default="initial")
@ -66,7 +64,6 @@ async def test_preserve_contexts_empty_generator():
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
"""
Test that context variables are preserved across event loop boundaries with nested generators.

View file

@ -6,7 +6,6 @@
import pytest
import pytest_asyncio
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import OpenAIFilePurpose
@ -29,7 +28,7 @@ class MockUploadFile:
return self.content
@pytest_asyncio.fixture
@pytest.fixture
async def files_provider(tmp_path):
"""Create a files provider with temporary storage for testing."""
storage_dir = tmp_path / "files"
@ -68,7 +67,6 @@ def large_file():
class TestOpenAIFilesAPI:
"""Test suite for OpenAI Files API endpoints."""
@pytest.mark.asyncio
async def test_upload_file_success(self, files_provider, sample_text_file):
"""Test successful file upload."""
# Upload file
@ -82,7 +80,6 @@ class TestOpenAIFilesAPI:
assert result.created_at > 0
assert result.expires_at > result.created_at
@pytest.mark.asyncio
async def test_upload_different_purposes(self, files_provider, sample_text_file):
"""Test uploading files with different purposes."""
purposes = list(OpenAIFilePurpose)
@ -93,7 +90,6 @@ class TestOpenAIFilesAPI:
uploaded_files.append(result)
assert result.purpose == purpose
@pytest.mark.asyncio
async def test_upload_different_file_types(self, files_provider, sample_text_file, sample_json_file, large_file):
"""Test uploading different types and sizes of files."""
files_to_test = [
@ -107,7 +103,6 @@ class TestOpenAIFilesAPI:
assert result.filename == expected_filename
assert result.bytes == len(file_obj.content)
@pytest.mark.asyncio
async def test_list_files_empty(self, files_provider):
"""Test listing files when no files exist."""
result = await files_provider.openai_list_files()
@ -117,7 +112,6 @@ class TestOpenAIFilesAPI:
assert result.first_id == ""
assert result.last_id == ""
@pytest.mark.asyncio
async def test_list_files_with_content(self, files_provider, sample_text_file, sample_json_file):
"""Test listing files when files exist."""
# Upload multiple files
@ -132,7 +126,6 @@ class TestOpenAIFilesAPI:
assert file1.id in file_ids
assert file2.id in file_ids
@pytest.mark.asyncio
async def test_list_files_with_purpose_filter(self, files_provider, sample_text_file):
"""Test listing files with purpose filtering."""
# Upload file with specific purpose
@ -146,7 +139,6 @@ class TestOpenAIFilesAPI:
assert result.data[0].id == uploaded_file.id
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
@pytest.mark.asyncio
async def test_list_files_with_limit(self, files_provider, sample_text_file):
"""Test listing files with limit parameter."""
# Upload multiple files
@ -157,7 +149,6 @@ class TestOpenAIFilesAPI:
result = await files_provider.openai_list_files(limit=3)
assert len(result.data) == 3
@pytest.mark.asyncio
async def test_list_files_with_order(self, files_provider, sample_text_file):
"""Test listing files with different order."""
# Upload multiple files
@ -178,7 +169,6 @@ class TestOpenAIFilesAPI:
# Oldest should be first
assert result_asc.data[0].created_at <= result_asc.data[1].created_at <= result_asc.data[2].created_at
@pytest.mark.asyncio
async def test_retrieve_file_success(self, files_provider, sample_text_file):
"""Test successful file retrieval."""
# Upload file
@ -197,13 +187,11 @@ class TestOpenAIFilesAPI:
assert retrieved_file.created_at == uploaded_file.created_at
assert retrieved_file.expires_at == uploaded_file.expires_at
@pytest.mark.asyncio
async def test_retrieve_file_not_found(self, files_provider):
"""Test retrieving a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_retrieve_file("file-nonexistent")
@pytest.mark.asyncio
async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
"""Test successful file content retrieval."""
# Upload file
@ -217,13 +205,11 @@ class TestOpenAIFilesAPI:
# Verify content
assert content.body == sample_text_file.content
@pytest.mark.asyncio
async def test_retrieve_file_content_not_found(self, files_provider):
"""Test retrieving content of a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_retrieve_file_content("file-nonexistent")
@pytest.mark.asyncio
async def test_delete_file_success(self, files_provider, sample_text_file):
"""Test successful file deletion."""
# Upload file
@ -245,13 +231,11 @@ class TestOpenAIFilesAPI:
with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
await files_provider.openai_retrieve_file(uploaded_file.id)
@pytest.mark.asyncio
async def test_delete_file_not_found(self, files_provider):
"""Test deleting a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_delete_file("file-nonexistent")
@pytest.mark.asyncio
async def test_file_persistence_across_operations(self, files_provider, sample_text_file):
"""Test that files persist correctly across multiple operations."""
# Upload file
@ -279,7 +263,6 @@ class TestOpenAIFilesAPI:
files_list = await files_provider.openai_list_files()
assert len(files_list.data) == 0
@pytest.mark.asyncio
async def test_multiple_files_operations(self, files_provider, sample_text_file, sample_json_file):
"""Test operations with multiple files."""
# Upload multiple files
@ -302,7 +285,6 @@ class TestOpenAIFilesAPI:
content = await files_provider.openai_retrieve_file_content(file2.id)
assert content.body == sample_json_file.content
@pytest.mark.asyncio
async def test_file_id_uniqueness(self, files_provider, sample_text_file):
"""Test that each uploaded file gets a unique ID."""
file_ids = set()
@ -316,7 +298,6 @@ class TestOpenAIFilesAPI:
file_ids.add(uploaded_file.id)
assert uploaded_file.id.startswith("file-")
@pytest.mark.asyncio
async def test_file_no_filename_handling(self, files_provider):
"""Test handling files with no filename."""
file_without_name = MockUploadFile(b"content", None) # No filename
@ -327,7 +308,6 @@ class TestOpenAIFilesAPI:
assert uploaded_file.filename == "uploaded_file" # Default filename
@pytest.mark.asyncio
async def test_after_pagination_works(self, files_provider, sample_text_file):
"""Test that 'after' pagination works correctly."""
# Upload multiple files to test pagination

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest_asyncio
import pytest
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def sqlite_kvstore(tmp_path):
db_path = tmp_path / "test_kv.db"
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
@ -20,14 +20,14 @@ async def sqlite_kvstore(tmp_path):
yield kvstore
@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def disk_dist_registry(sqlite_kvstore):
registry = DiskDistributionRegistry(sqlite_kvstore)
await registry.initialize()
yield registry
@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def cached_disk_dist_registry(sqlite_kvstore):
registry = CachedDiskDistributionRegistry(sqlite_kvstore)
await registry.initialize()

View file

@ -8,7 +8,6 @@ from datetime import datetime
from unittest.mock import AsyncMock
import pytest
import pytest_asyncio
from llama_stack.apis.agents import (
Agent,
@ -50,7 +49,7 @@ def config(tmp_path):
)
@pytest_asyncio.fixture
@pytest.fixture
async def agents_impl(config, mock_apis):
impl = MetaReferenceAgentsImpl(
config,
@ -117,7 +116,6 @@ def sample_agent_config():
)
@pytest.mark.asyncio
async def test_create_agent(agents_impl, sample_agent_config):
response = await agents_impl.create_agent(sample_agent_config)
@ -132,7 +130,6 @@ async def test_create_agent(agents_impl, sample_agent_config):
assert isinstance(agent_info.created_at, datetime)
@pytest.mark.asyncio
async def test_get_agent(agents_impl, sample_agent_config):
create_response = await agents_impl.create_agent(sample_agent_config)
agent_id = create_response.agent_id
@ -146,7 +143,6 @@ async def test_get_agent(agents_impl, sample_agent_config):
assert isinstance(agent.created_at, datetime)
@pytest.mark.asyncio
async def test_list_agents(agents_impl, sample_agent_config):
agent1_response = await agents_impl.create_agent(sample_agent_config)
agent2_response = await agents_impl.create_agent(sample_agent_config)
@ -160,7 +156,6 @@ async def test_list_agents(agents_impl, sample_agent_config):
assert agent2_response.agent_id in agent_ids
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting
@ -188,7 +183,6 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config
await agents_impl.get_agents_session(agent_id, session_response.session_id)
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting
@ -221,7 +215,6 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config,
assert session2.session_id in {s["session_id"] for s in sessions.data}
@pytest.mark.asyncio
async def test_delete_agent(agents_impl, sample_agent_config):
# Create an agent
response = await agents_impl.create_agent(sample_agent_config)

View file

@ -122,7 +122,6 @@ async def fake_stream(fixture: str = "simple_chat_completion.yaml"):
)
@pytest.mark.asyncio
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a simple string input."""
# Setup
@ -155,7 +154,6 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
assert result.output[0].content[0].text == "Dublin"
@pytest.mark.asyncio
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a simple string input and tools."""
# Setup
@ -224,7 +222,6 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
assert result.output[1].content[0].annotations == []
@pytest.mark.asyncio
async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a tool call response that has a type of None."""
# Setup
@ -294,7 +291,6 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
assert chunks[1].response.output[0].name == "get_weather"
@pytest.mark.asyncio
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with multiple messages."""
# Setup
@ -340,7 +336,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
@pytest.mark.asyncio
async def test_prepend_previous_response_none(openai_responses_impl):
"""Test prepending no previous response to a new response."""
@ -348,7 +343,6 @@ async def test_prepend_previous_response_none(openai_responses_impl):
assert input == "fake_input"
@pytest.mark.asyncio
async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
"""Test prepending a basic previous response to a new response."""
@ -388,7 +382,6 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
assert input[2].content == "fake_input"
@pytest.mark.asyncio
async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store):
"""Test prepending a web search previous response to a new response."""
input_item_message = OpenAIResponseMessage(
@ -434,7 +427,6 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
assert input[3].content == "fake_input"
@pytest.mark.asyncio
async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
# Setup
input_text = "What is the capital of Ireland?"
@ -463,7 +455,6 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
assert sent_messages[1].content == input_text
@pytest.mark.asyncio
async def test_create_openai_response_with_instructions_and_multiple_messages(
openai_responses_impl, mock_inference_api
):
@ -508,7 +499,6 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
assert sent_messages[3].content == "Which is the largest?"
@pytest.mark.asyncio
async def test_create_openai_response_with_instructions_and_previous_response(
openai_responses_impl, mock_responses_store, mock_inference_api
):
@ -565,7 +555,6 @@ async def test_create_openai_response_with_instructions_and_previous_response(
assert sent_messages[3].content == "Which is the largest?"
@pytest.mark.asyncio
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
# Setup
@ -601,7 +590,6 @@ async def test_list_openai_response_input_items_delegation(openai_responses_impl
assert result.data[0].id == "msg_123"
@pytest.mark.asyncio
async def test_responses_store_list_input_items_logic():
"""Test ResponsesStore list_response_input_items logic - mocks get_response_object to test actual ordering/limiting."""
@ -680,7 +668,6 @@ async def test_responses_store_list_input_items_logic():
assert len(result.data) == 0 # Should return no items
@pytest.mark.asyncio
async def test_store_response_uses_rehydrated_input_with_previous_response(
openai_responses_impl, mock_responses_store, mock_inference_api
):
@ -747,7 +734,6 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
assert result.status == "completed"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"text_format, response_format",
[
@ -787,7 +773,6 @@ async def test_create_openai_response_with_text_format(
assert first_call.kwargs["response_format"] == response_format
@pytest.mark.asyncio
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with an invalid text format."""
# Setup

View file

@ -9,7 +9,6 @@ from datetime import datetime
from unittest.mock import patch
import pytest
import pytest_asyncio
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
@ -17,13 +16,12 @@ from llama_stack.distribution.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest_asyncio.fixture
@pytest.fixture
async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
@ -44,7 +42,6 @@ async def test_session_creation_with_access_attributes(mock_get_authenticated_us
assert session_info.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
@ -79,7 +76,6 @@ async def test_session_access_control(mock_get_authenticated_user, test_setup):
assert retrieved_session is None
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
@ -133,7 +129,6 @@ async def test_turn_access_control(mock_get_authenticated_user, test_setup):
await agent_persistence.get_session_turns(session_id)
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup

View file

@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from unittest.mock import MagicMock
from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
def test_groq_provider_openai_client_caching():
"""Ensure the Groq provider does not cache api keys across client requests"""
config = GroqConfig()
inference_adapter = GroqInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
def test_openai_provider_openai_client_caching():
"""Ensure the OpenAI provider does not cache api keys across client requests"""
config = OpenAIConfig()
inference_adapter = OpenAIInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
def test_together_provider_openai_client_caching():
"""Ensure the Together provider does not cache api keys across client requests"""
config = TogetherImplConfig()
inference_adapter = TogetherInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"together_api_key": api_key})}):
together_client = inference_adapter._get_client()
assert together_client.client.api_key == api_key
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key

View file

@ -14,7 +14,6 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
@ -103,7 +102,7 @@ def mock_openai_models_list():
yield mock_list
@pytest_asyncio.fixture(scope="module")
@pytest.fixture(scope="module")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config)
@ -112,7 +111,6 @@ async def vllm_inference_adapter():
return inference_adapter
@pytest.mark.asyncio
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
async def mock_openai_models():
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
@ -125,7 +123,6 @@ async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inferenc
mock_openai_models_list.assert_called()
@pytest.mark.asyncio
async def test_old_vllm_tool_choice(vllm_inference_adapter):
"""
Test that we set tool_choice to none when no tools are in use
@ -149,7 +146,6 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
assert request.tool_config.tool_choice == ToolChoice.none
@pytest.mark.asyncio
async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format."""
@ -192,7 +188,6 @@ async def test_tool_call_response(vllm_inference_adapter):
]
@pytest.mark.asyncio
async def test_tool_call_delta_empty_tool_call_buf():
"""
Test that we don't generate extra chunks when processing a
@ -222,7 +217,6 @@ async def test_tool_call_delta_empty_tool_call_buf():
assert chunks[1].event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio
async def test_tool_call_delta_streaming_arguments_dict():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
@ -297,7 +291,6 @@ async def test_tool_call_delta_streaming_arguments_dict():
assert chunks[2].event.event_type.value == "complete"
@pytest.mark.asyncio
async def test_multiple_tool_calls():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
@ -376,7 +369,6 @@ async def test_multiple_tool_calls():
assert chunks[3].event.event_type.value == "complete"
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_no_choices():
"""
Test that we don't error out when vLLM returns no choices for a
@ -401,6 +393,7 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
assert chunks[0].event.event_type.value == "start"
@pytest.mark.allow_network
def test_chat_completion_doesnt_block_event_loop(caplog):
loop = asyncio.new_event_loop()
loop.set_debug(True)
@ -453,7 +446,6 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
assert not asyncio_warnings
@pytest.mark.asyncio
async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest(
tools=[],
@ -464,7 +456,6 @@ async def test_get_params_empty_tools(vllm_inference_adapter):
assert "tools" not in params
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
"""
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
@ -543,7 +534,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
"""
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
@ -596,7 +586,6 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
"""
Tests the edge case where no arguments are provided for the tool call.
@ -645,7 +634,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
assert chunks[-2].event.delta.tool_call.arguments == {}
@pytest.mark.asyncio
async def test_health_status_success(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when the connection is successful.
@ -679,7 +667,6 @@ async def test_health_status_success(vllm_inference_adapter):
mock_models.list.assert_called_once()
@pytest.mark.asyncio
async def test_health_status_failure(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when the connection fails.

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
@ -23,7 +22,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
)
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict():
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
assert await convert_message_to_openai_dict(message) == {
@ -33,7 +31,6 @@ async def test_convert_message_to_openai_dict():
# Test convert_message_to_openai_dict with a tool call
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict_with_tool_call():
message = CompletionMessage(
content="",
@ -54,7 +51,6 @@ async def test_convert_message_to_openai_dict_with_tool_call():
}
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
message = CompletionMessage(
content="",
@ -80,7 +76,6 @@ async def test_convert_message_to_openai_dict_with_builtin_tool_call():
}
@pytest.mark.asyncio
async def test_openai_messages_to_messages_with_content_str():
openai_messages = [
OpenAISystemMessageParam(content="system message"),
@ -98,7 +93,6 @@ async def test_openai_messages_to_messages_with_content_str():
assert llama_messages[2].content == "assistant message"
@pytest.mark.asyncio
async def test_openai_messages_to_messages_with_content_list():
openai_messages = [
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),

View file

@ -13,7 +13,6 @@ from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
@pytest.mark.asyncio
async def test_content_from_doc_with_url():
"""Test extracting content from RAGDocument with URL content."""
mock_url = URL(uri="https://example.com")
@ -33,7 +32,6 @@ async def test_content_from_doc_with_url():
mock_instance.get.assert_called_once_with(mock_url.uri)
@pytest.mark.asyncio
async def test_content_from_doc_with_pdf_url():
"""Test extracting content from RAGDocument with URL pointing to a PDF."""
mock_url = URL(uri="https://example.com/document.pdf")
@ -58,7 +56,6 @@ async def test_content_from_doc_with_pdf_url():
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
@pytest.mark.asyncio
async def test_content_from_doc_with_data_url():
"""Test extracting content from RAGDocument with data URL content."""
data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded
@ -74,7 +71,6 @@ async def test_content_from_doc_with_data_url():
mock_content_from_data.assert_called_once_with(data_url)
@pytest.mark.asyncio
async def test_content_from_doc_with_string():
"""Test extracting content from RAGDocument with string content."""
content_string = "This is plain text content"
@ -85,7 +81,6 @@ async def test_content_from_doc_with_string():
assert result == content_string
@pytest.mark.asyncio
async def test_content_from_doc_with_string_url():
"""Test extracting content from RAGDocument with string URL content."""
url_string = "https://example.com"
@ -105,7 +100,6 @@ async def test_content_from_doc_with_string_url():
mock_instance.get.assert_called_once_with(url_string)
@pytest.mark.asyncio
async def test_content_from_doc_with_string_pdf_url():
"""Test extracting content from RAGDocument with string URL pointing to a PDF."""
url_string = "https://example.com/document.pdf"
@ -130,7 +124,6 @@ async def test_content_from_doc_with_string_pdf_url():
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
@pytest.mark.asyncio
async def test_content_from_doc_with_interleaved_content():
"""Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit)."""
interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")]

View file

@ -94,8 +94,8 @@ class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
super().__init__(model_entries)
self._available_models = available_models
async def query_available_models(self) -> list[str]:
return self._available_models
async def check_model_availability(self, model: str) -> bool:
return model in self._available_models
@pytest.fixture
@ -118,18 +118,15 @@ def helper_with_dynamic_models(
)
@pytest.mark.asyncio
async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
assert helper.get_provider_model_id(unknown_model.model_id) is None
@pytest.mark.asyncio
async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
with pytest.raises(ValueError):
await helper.register_model(unknown_model)
@pytest.mark.asyncio
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
model = Model(
provider_id=known_model.provider_id,
@ -141,7 +138,6 @@ async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -
assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
model = Model(
provider_id=known_model.provider_id,
@ -153,13 +149,11 @@ async def test_register_model_from_alias(helper: ModelRegistryHelper, known_mode
assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
await helper.register_model(known_model)
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_existing_different(
helper: ModelRegistryHelper, known_model: Model, known_model2: Model
) -> None:
@ -168,7 +162,6 @@ async def test_register_model_existing_different(
await helper.register_model(known_model)
@pytest.mark.asyncio
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
await helper.register_model(known_model) # duplicate entry
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
@ -176,35 +169,31 @@ async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model)
assert helper.get_provider_model_id(known_model.model_id) is None
@pytest.mark.asyncio
async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
with pytest.raises(ValueError):
await helper.unregister_model(unknown_model.model_id)
@pytest.mark.asyncio
async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
@pytest.mark.asyncio
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
await helper.unregister_model(known_model.provider_resource_id)
assert helper.get_provider_model_id(known_model.provider_resource_id) is None
@pytest.mark.asyncio
async def test_register_model_from_query_available_models(
async def test_register_model_from_check_model_availability(
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
) -> None:
"""Test that models returned by query_available_models can be registered."""
"""Test that models returned by check_model_availability can be registered."""
# Verify the model is not in static config
assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None
# But it should be available via query_available_models
available_models = await helper_with_dynamic_models.query_available_models()
assert dynamic_model.provider_resource_id in available_models
# But it should be available via check_model_availability
is_available = await helper_with_dynamic_models.check_model_availability(dynamic_model.provider_resource_id)
assert is_available
# Registration should succeed
registered_model = await helper_with_dynamic_models.register_model(dynamic_model)
@ -216,7 +205,6 @@ async def test_register_model_from_query_available_models(
)
@pytest.mark.asyncio
async def test_register_model_not_in_static_or_dynamic(
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, unknown_model: Model
) -> None:
@ -224,20 +212,19 @@ async def test_register_model_not_in_static_or_dynamic(
# Verify the model is not in static config
assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None
# And not in dynamic models
available_models = await helper_with_dynamic_models.query_available_models()
assert unknown_model.provider_resource_id not in available_models
# And not available via check_model_availability
is_available = await helper_with_dynamic_models.check_model_availability(unknown_model.provider_resource_id)
assert not is_available
# Registration should fail with comprehensive error message
with pytest.raises(Exception) as exc_info: # UnsupportedModelError
await helper_with_dynamic_models.register_model(unknown_model)
# Error should include both static and dynamic models
# Error should include static models and "..." for dynamic models
error_str = str(exc_info.value)
assert "dynamic-provider-id" in error_str # dynamic model should be in error
assert "..." in error_str # "..." should be in error message
@pytest.mark.asyncio
async def test_register_alias_for_dynamic_model(
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
) -> None:

View file

@ -11,7 +11,6 @@ import pytest
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
@pytest.mark.asyncio
async def test_scheduler_unknown_backend():
with pytest.raises(ValueError):
Scheduler(backend="unknown")
@ -26,7 +25,6 @@ async def wait_for_job_completed(sched: Scheduler, job_id: str) -> None:
raise TimeoutError(f"Job {job_id} did not complete in time.")
@pytest.mark.asyncio
async def test_scheduler_naive():
sched = Scheduler()
@ -87,7 +85,6 @@ async def test_scheduler_naive():
assert job.logs[0][0] < job.logs[1][0]
@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises():
sched = Scheduler()

View file

@ -8,10 +8,20 @@ import random
import numpy as np
import pytest
from pymilvus import MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
EMBEDDING_DIMENSION = 384
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture
@ -50,7 +60,194 @@ def sample_chunks():
return sample
@pytest.fixture(scope="session")
def sample_chunks_with_metadata():
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
n, k = 10, 3
sample = [
Chunk(
content=f"Sentence {i} from document {j}",
metadata={"document_id": f"document-{j}"},
chunk_metadata=ChunkMetadata(
document_id=f"document-{j}",
chunk_id=f"document-{j}-chunk-{i}",
source=f"example source-{j}-{i}",
),
)
for j in range(k)
for i in range(n)
]
return sample
@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
@pytest.fixture(scope="session")
def sample_embeddings_with_metadata(sample_chunks_with_metadata):
np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss"])
def vector_provider(request):
return request.param
@pytest.fixture(scope="session")
def mock_inference_api(embedding_dimension):
class MockInferenceAPI:
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]
return MockInferenceAPI()
@pytest.fixture
async def unique_kvstore_config(tmp_path_factory):
# Generate a unique filename for this test
unique_id = f"test_kv_{np.random.randint(1e6)}"
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"{unique_id}.db")
return SqliteKVStoreConfig(db_path=db_path)
@pytest.fixture(scope="session")
def sqlite_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_sqlite_vec.db")
return db_path
@pytest.fixture
async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"test_sqlite_vec_{np.random.randint(1e6)}.db")
bank_id = f"sqlite_vec_bank_{np.random.randint(1e6)}"
index = SQLiteVecIndex(embedding_dimension, db_path, bank_id)
await index.initialize()
index.db_path = db_path
yield index
index.delete()
@pytest.fixture
async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
config = SQLiteVectorIOConfig(
db_path=sqlite_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = SQLiteVecVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
@pytest.fixture(scope="session")
def milvus_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
return db_path
@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
client = MilvusClient(milvus_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = milvus_vec_db_path
yield index
@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
config = MilvusVectorIOConfig(
db_path=milvus_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = MilvusVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=adapter.metadata_collection_name,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def faiss_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
return db_path
@pytest.fixture
async def faiss_vec_index(embedding_dimension):
index = FaissIndex(embedding_dimension)
yield index
await index.delete()
@pytest.fixture
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
config = FaissVectorIOConfig(
kvstore=unique_kvstore_config,
)
adapter = FaissVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def vector_io_adapter(vector_provider, request):
"""Returns the appropriate vector IO adapter based on the provider parameter."""
if vector_provider == "milvus":
return request.getfixturevalue("milvus_vec_adapter")
elif vector_provider == "faiss":
return request.getfixturevalue("faiss_vec_adapter")
else:
return request.getfixturevalue("sqlite_vec_adapter")
@pytest.fixture
def vector_index(vector_provider, request):
"""Returns appropriate vector index based on provider parameter"""
return request.getfixturevalue(f"{vector_provider}_vec_index")

View file

@ -9,7 +9,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.files import Files
from llama_stack.apis.inference import EmbeddingsResponse, Inference
@ -91,13 +90,13 @@ def faiss_config():
return config
@pytest_asyncio.fixture
@pytest.fixture
async def faiss_index(embedding_dimension):
index = await FaissIndex.create(dimension=embedding_dimension)
yield index
@pytest_asyncio.fixture
@pytest.fixture
async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
# Create the adapter
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
@ -113,7 +112,6 @@ async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> Fai
yield adapter
@pytest.mark.asyncio
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
):
@ -136,7 +134,6 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
assert response.chunks[1] == sample_chunks[1]
@pytest.mark.asyncio
async def test_health_success():
"""Test that the health check returns OK status when faiss is working correctly."""
# Create a fresh instance of FaissVectorIOAdapter for testing
@ -160,7 +157,6 @@ async def test_health_success():
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
@pytest.mark.asyncio
async def test_health_failure():
"""Test that the health check returns ERROR status when faiss encounters an error."""
# Create a fresh instance of FaissVectorIOAdapter for testing

View file

@ -10,7 +10,6 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_io import (
@ -68,7 +67,7 @@ def mock_api_service(sample_embeddings):
return mock_api_service
@pytest_asyncio.fixture
@pytest.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
adapter.vector_db_store = mock_vector_db_store
@ -80,7 +79,6 @@ async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service,
__QUERY = "Sample query"
@pytest.mark.asyncio
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
async def test_qdrant_adapter_returns_expected_chunks(
qdrant_adapter: QdrantVectorIOAdapter,
@ -111,7 +109,6 @@ def _prepare_for_json(value: Any) -> str:
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
@pytest.mark.asyncio
async def test_qdrant_register_and_unregister_vector_db(
qdrant_adapter: QdrantVectorIOAdapter,
mock_vector_db,

View file

@ -8,7 +8,6 @@ import asyncio
import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
@ -34,7 +33,7 @@ def loop():
return asyncio.new_event_loop()
@pytest_asyncio.fixture(scope="session", autouse=True)
@pytest.fixture
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db")
@ -43,39 +42,14 @@ async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
await index.delete()
@pytest.mark.asyncio
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
count = cur.fetchone()[0]
assert count == len(sample_chunks)
cur.close()
connection.close()
async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata):
await sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)
response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0)
assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata
@pytest.mark.asyncio
async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
@pytest.mark.xfail(reason="Chunk Metadata not yet supported for SQLite-vec", strict=True)
async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = sample_embeddings[0]
response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert response.chunks[-1].chunk_metadata == sample_chunks[-1].chunk_metadata
@pytest.mark.asyncio
async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_string = "Sentence 5"
response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string)
@ -91,7 +65,6 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
@pytest.mark.asyncio
async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -113,7 +86,6 @@ async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embed
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
# Re-initialize with a clean index
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -126,7 +98,6 @@ async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_i
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"
@pytest.mark.asyncio
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
# Reduce batch size to force multiple batches for same document
@ -148,7 +119,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
@pytest_asyncio.fixture(scope="session")
@pytest.fixture(scope="session")
async def sqlite_vec_adapter(sqlite_connection):
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
@ -157,7 +128,6 @@ async def sqlite_vec_adapter(sqlite_connection):
await adapter.shutdown()
@pytest.mark.asyncio
async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search when keyword search returns no matches - should still return vector results."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -186,7 +156,6 @@ async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_c
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search with a high score threshold."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -208,7 +177,6 @@ async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chun
assert len(response.chunks) == 0
@pytest.mark.asyncio
async def test_query_chunks_hybrid_different_embedding(
sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
):
@ -234,7 +202,6 @@ async def test_query_chunks_hybrid_different_embedding(
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test that RRF properly combines rankings when documents appear in both search methods."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -259,7 +226,6 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -307,7 +273,6 @@ async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chun
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
@pytest.mark.asyncio
async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search with documents that appear in only one search method."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -336,7 +301,6 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
assert "document-2" in doc_ids # From keyword search
@pytest.mark.asyncio
async def test_query_chunks_hybrid_weighted_reranker_parametrization(
sqlite_vec_index, sample_chunks, sample_embeddings
):
@ -392,7 +356,6 @@ async def test_query_chunks_hybrid_weighted_reranker_parametrization(
)
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test RRFReRanker with different impact factors."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -424,7 +387,6 @@ async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_ch
assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
@pytest.mark.asyncio
async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -468,7 +430,6 @@ async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, s
assert len(response.chunks) <= 100
@pytest.mark.asyncio
async def test_query_chunks_hybrid_tie_breaking(
sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
):

View file

@ -4,253 +4,130 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import time
from unittest.mock import AsyncMock
import numpy as np
import pytest
import pytest_asyncio
from pymilvus import Collection, MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX, MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX
# TODO: Refactor these to be for inline vector-io providers
MILVUS_ALIAS = "test_milvus"
COLLECTION_PREFIX = "test_collection"
# This test is a unit test for the inline VectoerIO providers. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
async def test_initialize_index(vector_index):
await vector_index.initialize()
@pytest.fixture(scope="session")
def mock_inference_api(embedding_dimension):
class MockInferenceAPI:
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]
return MockInferenceAPI()
@pytest_asyncio.fixture
async def unique_kvstore_config(tmp_path_factory):
# Generate a unique filename for this test
unique_id = f"test_kv_{np.random.randint(1e6)}"
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"{unique_id}.db")
return SqliteKVStoreConfig(db_path=db_path)
@pytest_asyncio.fixture(scope="session", autouse=True)
async def milvus_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_milvus.db")
client = MilvusClient(db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = db_path
yield index
@pytest_asyncio.fixture(scope="session")
async def milvus_vec_adapter(milvus_vec_index, mock_inference_api):
config = MilvusVectorIOConfig(
db_path=milvus_vec_index.db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = MilvusVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=adapter.metadata_collection_name,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
)
yield adapter
await adapter.shutdown()
@pytest.mark.asyncio
async def test_cache_contains_initial_collection(milvus_vec_adapter):
coll_name = milvus_vec_adapter.metadata_collection_name
assert coll_name in milvus_vec_adapter.cache
@pytest.mark.asyncio
async def test_add_chunks(milvus_vec_index, sample_chunks, sample_embeddings):
await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
resp = await milvus_vec_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
vector_index.delete()
vector_index.initialize()
await vector_index.add_chunks(sample_chunks, sample_embeddings)
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
assert resp.chunks[0].content == sample_chunks[0].content
vector_index.delete()
@pytest.mark.asyncio
async def test_query_chunks_vector(milvus_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_emb = np.random.rand(embedding_dimension).astype(np.float32)
resp = await milvus_vec_index.query_vector(query_emb, k=2, score_threshold=0.0)
assert isinstance(resp, QueryChunksResponse)
assert len(resp.chunks) == 2
@pytest.mark.asyncio
async def test_chunk_id_conflict(milvus_vec_index, sample_chunks, embedding_dimension):
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
await milvus_vec_index.add_chunks(sample_chunks, embeddings)
coll = Collection(milvus_vec_index.collection_name, using=MILVUS_ALIAS)
ids = coll.query(expr="id >= 0", output_fields=["id"], timeout=30)
flat_ids = [i["id"] for i in ids]
assert len(flat_ids) == len(set(flat_ids))
@pytest.mark.asyncio
async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_config):
kvstore = await kvstore_impl(unique_kvstore_config)
vector_db = VectorDB(
identifier="test_db",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
metadata={"test_key": "test_value"},
)
test_vector_db_data = vector_db.model_dump_json()
await kvstore.set(f"{VECTOR_DBS_PREFIX}test_db", test_vector_db_data)
tmp_milvus_vec_adapter = MilvusVectorIOAdapter(
config=MilvusVectorIOConfig(
db_path=milvus_vec_index.db_path,
kvstore=unique_kvstore_config,
),
inference_api=None,
files_api=None,
)
await tmp_milvus_vec_adapter.initialize()
vector_db = VectorDB(
identifier="test_db",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
test_vector_db_data = vector_db.model_dump_json()
await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data)
assert milvus_vec_index.client is not None
assert isinstance(milvus_vec_index.client, MilvusClient)
assert tmp_milvus_vec_adapter.cache is not None
# registering a vector won't update the cache or openai_vector_store collection name
assert (
tmp_milvus_vec_adapter.metadata_collection_name not in tmp_milvus_vec_adapter.cache
or tmp_milvus_vec_adapter.openai_vector_stores
await vector_index.add_chunks(sample_chunks, embeddings)
resp = await vector_index.query_vector(
np.random.rand(embedding_dimension).astype(np.float32),
k=len(sample_chunks),
score_threshold=-1,
)
contents = [chunk.content for chunk in resp.chunks]
assert len(contents) == len(set(contents))
@pytest.mark.asyncio
async def test_persistence_across_adapter_restarts(
tmp_path, milvus_vec_index, mock_inference_api, unique_kvstore_config
):
adapter1 = MilvusVectorIOAdapter(
config=MilvusVectorIOConfig(db_path=milvus_vec_index.db_path, kvstore=unique_kvstore_config),
inference_api=mock_inference_api,
files_api=None,
)
await adapter1.initialize()
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
key = f"{VECTOR_DBS_PREFIX}db1"
dummy = VectorDB(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await adapter1.register_vector_db(dummy)
await adapter1.shutdown()
await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
await adapter1.initialize()
assert "foo_db" in adapter1.cache
await adapter1.shutdown()
await vector_io_adapter.initialize()
@pytest.mark.asyncio
async def test_register_and_unregister_vector_db(milvus_vec_adapter):
try:
connections.disconnect(MILVUS_ALIAS)
except Exception as _:
pass
async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.initialize()
dummy = VectorDB(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.register_vector_db(dummy)
await vector_io_adapter.shutdown()
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_adapter.config.db_path)
await vector_io_adapter.initialize()
assert "foo_db" in vector_io_adapter.cache
await vector_io_adapter.shutdown()
async def test_register_and_unregister_vector_db(vector_io_adapter):
unique_id = f"foo_db_{np.random.randint(1e6)}"
dummy = VectorDB(
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await milvus_vec_adapter.register_vector_db(dummy)
assert dummy.identifier in milvus_vec_adapter.cache
if dummy.identifier in milvus_vec_adapter.cache:
index = milvus_vec_adapter.cache[dummy.identifier].index
if hasattr(index, "client") and hasattr(index.client, "_using"):
index.client._using = MILVUS_ALIAS
await milvus_vec_adapter.unregister_vector_db(dummy.identifier)
assert dummy.identifier not in milvus_vec_adapter.cache
await vector_io_adapter.register_vector_db(dummy)
assert dummy.identifier in vector_io_adapter.cache
await vector_io_adapter.unregister_vector_db(dummy.identifier)
assert dummy.identifier not in vector_io_adapter.cache
@pytest.mark.asyncio
async def test_query_unregistered_raises(milvus_vec_adapter):
async def test_query_unregistered_raises(vector_io_adapter):
fake_emb = np.zeros(8, dtype=np.float32)
with pytest.raises(AttributeError):
await milvus_vec_adapter.query_chunks("no_such_db", fake_emb)
with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
@pytest.mark.asyncio
async def test_insert_chunks_calls_underlying_index(milvus_vec_adapter):
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
fake_index = AsyncMock()
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
vector_io_adapter.cache["db1"] = fake_index
chunks = ["chunk1", "chunk2"]
await milvus_vec_adapter.insert_chunks("db1", chunks)
await vector_io_adapter.insert_chunks("db1", chunks)
fake_index.insert_chunks.assert_awaited_once_with(chunks)
@pytest.mark.asyncio
async def test_insert_chunks_missing_db_raises(milvus_vec_adapter):
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
with pytest.raises(ValueError):
await milvus_vec_adapter.insert_chunks("db_not_exist", [])
await vector_io_adapter.insert_chunks("db_not_exist", [])
@pytest.mark.asyncio
async def test_query_chunks_calls_underlying_index_and_returns(milvus_vec_adapter):
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
vector_io_adapter.cache["db1"] = fake_index
response = await milvus_vec_adapter.query_chunks("db1", "my_query", {"param": 1})
response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1})
fake_index.query_chunks.assert_awaited_once_with("my_query", {"param": 1})
assert response is expected
@pytest.mark.asyncio
async def test_query_chunks_missing_db_raises(milvus_vec_adapter):
milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
async def test_query_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
with pytest.raises(ValueError):
await milvus_vec_adapter.query_chunks("db_missing", "q", None)
await vector_io_adapter.query_chunks("db_missing", "q", None)
@pytest.mark.asyncio
async def test_save_openai_vector_store(milvus_vec_adapter):
async def test_save_openai_vector_store(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
"id": store_id,
@ -260,14 +137,13 @@ async def test_save_openai_vector_store(milvus_vec_adapter):
"embedding_model": "test_model",
}
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
assert openai_vector_store["id"] in milvus_vec_adapter.openai_vector_stores
assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
assert openai_vector_store["id"] in vector_io_adapter.openai_vector_stores
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
@pytest.mark.asyncio
async def test_update_openai_vector_store(milvus_vec_adapter):
async def test_update_openai_vector_store(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
"id": store_id,
@ -277,14 +153,13 @@ async def test_update_openai_vector_store(milvus_vec_adapter):
"embedding_model": "test_model",
}
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
openai_vector_store["description"] = "Updated description"
await milvus_vec_adapter._update_openai_vector_store(store_id, openai_vector_store)
assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
await vector_io_adapter._update_openai_vector_store(store_id, openai_vector_store)
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
@pytest.mark.asyncio
async def test_delete_openai_vector_store(milvus_vec_adapter):
async def test_delete_openai_vector_store(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
"id": store_id,
@ -294,13 +169,12 @@ async def test_delete_openai_vector_store(milvus_vec_adapter):
"embedding_model": "test_model",
}
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
await milvus_vec_adapter._delete_openai_vector_store_from_storage(store_id)
assert openai_vector_store["id"] not in milvus_vec_adapter.openai_vector_stores
await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
await vector_io_adapter._delete_openai_vector_store_from_storage(store_id)
assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores
@pytest.mark.asyncio
async def test_load_openai_vector_stores(milvus_vec_adapter):
async def test_load_openai_vector_stores(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
"id": store_id,
@ -310,13 +184,12 @@ async def test_load_openai_vector_stores(milvus_vec_adapter):
"embedding_model": "test_model",
}
await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store)
loaded_stores = await milvus_vec_adapter._load_openai_vector_stores()
await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
loaded_stores = await vector_io_adapter._load_openai_vector_stores()
assert loaded_stores[store_id] == openai_vector_store
@pytest.mark.asyncio
async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@ -334,11 +207,10 @@ async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factor
]
# validating we don't raise an exception
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
@pytest.mark.asyncio
async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@ -355,24 +227,23 @@ async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_fact
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
]
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
updated_file_info = file_info.copy()
updated_file_info["filename"] = "updated_test_file.txt"
await milvus_vec_adapter._update_openai_vector_store_file(
await vector_io_adapter._update_openai_vector_store_file(
store_id,
file_id,
updated_file_info,
)
loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file(store_id, file_id)
loaded_contents = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id)
assert loaded_contents == updated_file_info
assert loaded_contents != file_info
@pytest.mark.asyncio
async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_path_factory):
async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@ -389,14 +260,13 @@ async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_pa
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
]
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
assert loaded_contents == file_contents
@pytest.mark.asyncio
async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, tmp_path_factory):
async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@ -413,8 +283,10 @@ async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter,
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
]
await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
await milvus_vec_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)
loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id)
loaded_file_info = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id)
assert loaded_file_info == {}
loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
assert loaded_contents == []

View file

@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.tools.rag_tool import RAGQueryConfig
from llama_stack.apis.vector_io import (
Chunk,
ChunkMetadata,
@ -17,13 +18,11 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
class TestRagQuery:
@pytest.mark.asyncio
async def test_query_raises_on_empty_vector_db_ids(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
with pytest.raises(ValueError):
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
@pytest.mark.asyncio
async def test_query_chunk_metadata_handling(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
content = "test query content"
@ -60,3 +59,14 @@ class TestRagQuery:
)
assert expected_metadata_string in result.content[1].text
assert result.content is not None
async def test_query_raises_incorrect_mode(self):
with pytest.raises(ValueError):
RAGQueryConfig(mode="invalid_mode")
@pytest.mark.asyncio
async def test_query_accepts_valid_modes(self):
RAGQueryConfig() # Test default (vector)
RAGQueryConfig(mode="vector") # Test vector
RAGQueryConfig(mode="keyword") # Test keyword
RAGQueryConfig(mode="hybrid") # Test hybrid

View file

@ -112,7 +112,6 @@ class TestValidateEmbedding:
class TestVectorStore:
@pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = RAGDocument(
@ -124,7 +123,7 @@ class TestVectorStore:
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.asyncio
@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
@ -137,7 +136,7 @@ class TestVectorStore:
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.asyncio
@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
@ -204,7 +203,6 @@ class TestVectorStore:
class TestVectorDBWithIndex:
@pytest.mark.asyncio
async def test_insert_chunks_without_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model without embeddings"
@ -230,7 +228,6 @@ class TestVectorDBWithIndex:
assert args[0] == chunks
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
@pytest.mark.asyncio
async def test_insert_chunks_with_valid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with embeddings"
@ -255,7 +252,6 @@ class TestVectorDBWithIndex:
assert args[0] == chunks
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
@pytest.mark.asyncio
async def test_insert_chunks_with_invalid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_dimension = 3
@ -295,7 +291,6 @@ class TestVectorDBWithIndex:
mock_inference_api.embeddings.assert_not_called()
mock_index.add_chunks.assert_not_called()
@pytest.mark.asyncio
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with partial embeddings"

View file

@ -38,14 +38,12 @@ def sample_model():
)
@pytest.mark.asyncio
async def test_registry_initialization(disk_dist_registry):
# Test empty registry
result = await disk_dist_registry.get("nonexistent", "nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
print(f"Registering {sample_vector_db}")
await disk_dist_registry.register(sample_vector_db)
@ -64,7 +62,6 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m
assert result_model.provider_id == sample_model.provider_id
@pytest.mark.asyncio
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
# First populate the disk registry
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
@ -85,7 +82,6 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db,
assert result_vector_db.provider_id == sample_vector_db.provider_id
@pytest.mark.asyncio
async def test_cached_registry_updates(cached_disk_dist_registry):
new_vector_db = VectorDB(
identifier="test_vector_db_2",
@ -112,7 +108,6 @@ async def test_cached_registry_updates(cached_disk_dist_registry):
assert result_vector_db.provider_id == new_vector_db.provider_id
@pytest.mark.asyncio
async def test_duplicate_provider_registration(cached_disk_dist_registry):
original_vector_db = VectorDB(
identifier="test_vector_db_2",
@ -137,7 +132,6 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
@pytest.mark.asyncio
async def test_get_all_objects(cached_disk_dist_registry):
# Create multiple test banks
# Create multiple test banks
@ -170,7 +164,6 @@ async def test_get_all_objects(cached_disk_dist_registry):
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
@pytest.mark.asyncio
async def test_parse_registry_values_error_handling(sqlite_kvstore):
valid_db = VectorDB(
identifier="valid_vector_db",
@ -209,7 +202,6 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
assert invalid_obj is None
@pytest.mark.asyncio
async def test_cached_registry_error_handling(sqlite_kvstore):
valid_db = VectorDB(
identifier="valid_cached_db",

View file

@ -5,14 +5,11 @@
# the root directory of this source tree.
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithOwner, User
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
@pytest.mark.asyncio
async def test_registry_cache_with_acl(cached_disk_dist_registry):
model = ModelWithOwner(
identifier="model-acl",
@ -48,7 +45,6 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
assert new_model.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
async def test_registry_empty_acl(cached_disk_dist_registry):
model = ModelWithOwner(
identifier="model-empty-acl",
@ -85,7 +81,6 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
assert len(all_models) == 2
@pytest.mark.asyncio
async def test_registry_serialization(cached_disk_dist_registry):
attributes = {
"roles": ["admin", "researcher"],

View file

@ -7,7 +7,6 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
import pytest_asyncio
import yaml
from pydantic import TypeAdapter, ValidationError
@ -27,7 +26,7 @@ def _return_model(model):
return model
@pytest_asyncio.fixture
@pytest.fixture
async def test_setup(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
@ -41,7 +40,6 @@ async def test_setup(cached_disk_dist_registry):
yield cached_disk_dist_registry, routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
@ -106,7 +104,6 @@ async def test_access_control_with_cache(mock_get_authenticated_user, test_setup
await routing_table.get_model("model-admin")
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
@ -145,7 +142,6 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
assert model.identifier == "model-updates"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
@ -170,7 +166,6 @@ async def test_access_control_empty_attributes(mock_get_authenticated_user, test
assert "model-empty-attrs" in model_ids
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
@ -201,7 +196,6 @@ async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
assert all_models.data[0].identifier == "model-public-2"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
"""Test that newly created resources inherit access attributes from their creator."""
@ -246,7 +240,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
assert model.identifier == "auto-access-model"
@pytest_asyncio.fixture
@pytest.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
@ -281,7 +275,6 @@ async def test_setup_with_access_policy(cached_disk_dist_registry):
yield routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
routing_table = test_setup_with_access_policy

View file

@ -202,7 +202,6 @@ def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoin
assert "param2" in payload["request"]["params"]
@pytest.mark.asyncio
async def test_http_middleware_with_access_attributes(mock_http_middleware, mock_scope):
"""Test HTTP middleware behavior with access attributes"""
middleware, mock_app = mock_http_middleware

View file

@ -9,7 +9,6 @@ import sys
from typing import Any, Protocol
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Inference
@ -66,7 +65,6 @@ class SampleImpl:
pass
@pytest.mark.asyncio
async def test_resolve_impls_basic():
# Create a real provider spec
provider_spec = InlineProviderSpec(

View file

@ -7,13 +7,10 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.server.server import create_dynamic_typed_route, create_sse_event, sse_generator
@pytest.mark.asyncio
async def test_sse_generator_basic():
# An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen():
@ -35,7 +32,6 @@ async def test_sse_generator_basic():
assert seen_events[1] == create_sse_event("Test event 2")
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected():
# An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen():
@ -58,7 +54,6 @@ async def test_sse_generator_client_disconnected():
assert seen_events[0] == create_sse_event("Test event 1")
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected_before_response_starts():
# Disconnect before the response starts
async def async_event_gen():
@ -75,7 +70,6 @@ async def test_sse_generator_client_disconnected_before_response_starts():
assert len(seen_events) == 0
@pytest.mark.asyncio
async def test_sse_generator_error_before_response_starts():
# Raise an error before the response starts
async def async_event_gen():
@ -93,7 +87,6 @@ async def test_sse_generator_error_before_response_starts():
assert 'data: {"error":' in seen_events[0]
@pytest.mark.asyncio
async def test_paginated_response_url_setting():
"""Test that PaginatedResponse gets url set to route path."""

View file

@ -42,7 +42,6 @@ def create_test_chat_completion(
)
@pytest.mark.asyncio
async def test_inference_store_pagination_basic():
"""Test basic pagination functionality."""
with TemporaryDirectory() as tmp_dir:
@ -88,7 +87,6 @@ async def test_inference_store_pagination_basic():
assert result3.has_more is False
@pytest.mark.asyncio
async def test_inference_store_pagination_ascending():
"""Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir:
@ -123,7 +121,6 @@ async def test_inference_store_pagination_ascending():
assert result2.has_more is True
@pytest.mark.asyncio
async def test_inference_store_pagination_with_model_filter():
"""Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir:
@ -161,7 +158,6 @@ async def test_inference_store_pagination_with_model_filter():
assert result2.has_more is False
@pytest.mark.asyncio
async def test_inference_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir:
@ -174,7 +170,6 @@ async def test_inference_store_pagination_invalid_after():
await store.list_chat_completions(after="non-existent", limit=2)
@pytest.mark.asyncio
async def test_inference_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir:

View file

@ -44,7 +44,6 @@ def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInp
)
@pytest.mark.asyncio
async def test_responses_store_pagination_basic():
"""Test basic pagination functionality for responses store."""
with TemporaryDirectory() as tmp_dir:
@ -90,7 +89,6 @@ async def test_responses_store_pagination_basic():
assert result3.has_more is False
@pytest.mark.asyncio
async def test_responses_store_pagination_ascending():
"""Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir:
@ -125,7 +123,6 @@ async def test_responses_store_pagination_ascending():
assert result2.has_more is True
@pytest.mark.asyncio
async def test_responses_store_pagination_with_model_filter():
"""Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir:
@ -163,7 +160,6 @@ async def test_responses_store_pagination_with_model_filter():
assert result2.has_more is False
@pytest.mark.asyncio
async def test_responses_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir:
@ -176,7 +172,6 @@ async def test_responses_store_pagination_invalid_after():
await store.list_responses(after="non-existent", limit=2)
@pytest.mark.asyncio
async def test_responses_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir:
@ -205,7 +200,6 @@ async def test_responses_store_pagination_no_limit():
assert result.has_more is False
@pytest.mark.asyncio
async def test_responses_store_get_response_object():
"""Test retrieving a single response object."""
with TemporaryDirectory() as tmp_dir:
@ -230,7 +224,6 @@ async def test_responses_store_get_response_object():
await store.get_response_object("non-existent")
@pytest.mark.asyncio
async def test_responses_store_input_items_pagination():
"""Test pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir:
@ -308,7 +301,6 @@ async def test_responses_store_input_items_pagination():
await store.list_response_input_items("test-resp", before="some-id", after="other-id")
@pytest.mark.asyncio
async def test_responses_store_input_items_before_pagination():
"""Test before pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir:

View file

@ -14,7 +14,6 @@ from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemyS
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@pytest.mark.asyncio
async def test_sqlite_sqlstore():
with TemporaryDirectory() as tmp_dir:
db_name = "test.db"
@ -66,7 +65,6 @@ async def test_sqlite_sqlstore():
assert result.has_more is False
@pytest.mark.asyncio
async def test_sqlstore_pagination_basic():
"""Test basic pagination functionality at the SQL store level."""
with TemporaryDirectory() as tmp_dir:
@ -131,7 +129,6 @@ async def test_sqlstore_pagination_basic():
assert result3.has_more is False
@pytest.mark.asyncio
async def test_sqlstore_pagination_with_filter():
"""Test pagination with WHERE conditions."""
with TemporaryDirectory() as tmp_dir:
@ -184,7 +181,6 @@ async def test_sqlstore_pagination_with_filter():
assert result2.has_more is False
@pytest.mark.asyncio
async def test_sqlstore_pagination_ascending_order():
"""Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir:
@ -233,7 +229,6 @@ async def test_sqlstore_pagination_ascending_order():
assert result2.has_more is True
@pytest.mark.asyncio
async def test_sqlstore_pagination_multi_column_ordering_error():
"""Test that multi-column ordering raises an error when using cursor pagination."""
with TemporaryDirectory() as tmp_dir:
@ -271,7 +266,6 @@ async def test_sqlstore_pagination_multi_column_ordering_error():
assert result.data[0]["id"] == "task1"
@pytest.mark.asyncio
async def test_sqlstore_pagination_cursor_requires_order_by():
"""Test that cursor pagination requires order_by parameter."""
with TemporaryDirectory() as tmp_dir:
@ -289,7 +283,6 @@ async def test_sqlstore_pagination_cursor_requires_order_by():
)
@pytest.mark.asyncio
async def test_sqlstore_pagination_error_handling():
"""Test error handling for invalid columns and cursor IDs."""
with TemporaryDirectory() as tmp_dir:
@ -339,7 +332,6 @@ async def test_sqlstore_pagination_error_handling():
)
@pytest.mark.asyncio
async def test_sqlstore_pagination_custom_key_column():
"""Test pagination with custom primary key column (not 'id')."""
with TemporaryDirectory() as tmp_dir:

View file

@ -7,8 +7,6 @@
from tempfile import TemporaryDirectory
from unittest.mock import patch
import pytest
from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
from llama_stack.distribution.access_control.datatypes import Action
from llama_stack.distribution.datatypes import User
@ -18,7 +16,6 @@ from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemyS
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@pytest.mark.asyncio
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
"""Test that fetch_all works correctly with where_sql for access control"""
@ -81,7 +78,6 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
assert row["title"] == "User Document"
@pytest.mark.asyncio
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_sql_policy_consistency(mock_get_authenticated_user):
"""Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic"""
@ -153,7 +149,9 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
policy_ids = set()
for scenario in test_scenarios:
sql_record = SqlRecord(
record_id=scenario["id"], table_name="resources", access_attributes=scenario["access_attributes"]
record_id=scenario["id"],
table_name="resources",
owner=User(principal="test-user", attributes=scenario["access_attributes"]),
)
if is_action_allowed(policy, Action.READ, sql_record, user):
@ -166,7 +164,6 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
)
@pytest.mark.asyncio
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
"""Test that user attributes are properly captured during insert"""