From 01c222e12f8b5e6c1cf8c2661bfe69e5680415c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 10 Jul 2025 15:16:08 +0200 Subject: [PATCH 01/11] ci: run all APIs integration tests (#2646) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? We are now automatically building the list of integration tests to run. In that process, eval and files are now being tested. This is pending https://github.com/meta-llama/llama-stack/pull/2628 Signed-off-by: Sébastien Han --- .github/actions/setup-ollama/action.yml | 2 ++ .github/workflows/integration-tests.yml | 39 +++++++++++++++---------- tests/integration/fixtures/common.py | 1 + 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/.github/actions/setup-ollama/action.yml b/.github/actions/setup-ollama/action.yml index da24839c2..37a369a9a 100644 --- a/.github/actions/setup-ollama/action.yml +++ b/.github/actions/setup-ollama/action.yml @@ -8,4 +8,6 @@ runs: run: | docker run -d --name ollama -p 11434:11434 docker.io/leseb/ollama-with-models # TODO: rebuild an ollama image with llama-guard3:1b + echo "Verifying Ollama status..." + timeout 30 bash -c 'while ! curl -s -L http://127.0.0.1:11434; do sleep 1 && echo "."; done' docker exec ollama ollama pull llama-guard3:1b diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index b102191f2..c46100c38 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -18,16 +18,33 @@ concurrency: cancel-in-progress: true jobs: - test-matrix: + discover-tests: runs-on: ubuntu-latest + outputs: + test-type: ${{ steps.generate-matrix.outputs.test-type }} + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Generate test matrix + id: generate-matrix + run: | + # Get test directories dynamically, excluding non-test directories + TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" | + grep -Ev "^(__pycache__|fixtures|test_cases)$" | + sort | jq -R -s -c 'split("\n")[:-1]') + echo "test-type=$TEST_TYPES" >> $GITHUB_OUTPUT + + test-matrix: + needs: discover-tests + runs-on: ubuntu-latest + strategy: + fail-fast: false matrix: - # Listing tests manually since some of them currently fail - # TODO: generate matrix list from tests/integration when fixed - test-type: [agents, inference, datasets, inspect, safety, scoring, post_training, providers, tool_runtime, vector_io] + test-type: ${{ fromJson(needs.discover-tests.outputs.test-type) }} client-type: [library, server] python-version: ["3.12", "3.13"] - fail-fast: false # we want to run all tests regardless of failure steps: - name: Checkout repository @@ -51,23 +68,13 @@ jobs: free -h df -h - - name: Verify Ollama status is OK - if: matrix.client-type == 'http' - run: | - echo "Verifying Ollama status..."
- ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status) - echo "Ollama status: $ollama_status" - if [ "$ollama_status" != "OK" ]; then - echo "Ollama health check failed" - exit 1 - fi - - name: Run Integration Tests env: OLLAMA_INFERENCE_MODEL: "llama3.2:3b-instruct-fp16" # for server tests ENABLE_OLLAMA: "ollama" # for server tests OLLAMA_URL: "http://0.0.0.0:11434" SAFETY_MODEL: "llama-guard3:1b" + LLAMA_STACK_CLIENT_TIMEOUT: "300" # Increased timeout for eval operations # Use 'shell' to get pipefail behavior # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#exit-codes-and-error-action-preference # TODO: write a precommit hook to detect if a test contains a pipe but does not use 'shell: bash' diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index 28a047ea5..749793b64 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -218,6 +218,7 @@ def llama_stack_client(request, provider_data): return LlamaStackClient( base_url=base_url, provider_data=provider_data, + timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")), ) # check if this looks like a URL using proper URL parsing From 81ebaf6e9a1744c36941baba5295337a8a7eb2af Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Thu, 10 Jul 2025 10:19:12 -0400 Subject: [PATCH 02/11] fix: properly represent paths in server logs (#2698) # What does this PR do? currently when logging the run yaml, if there are path objects in the object they are represented as: ``` external_providers_dir: !!python/object/apply:pathlib.PosixPath - '~' - .llama - providers.d ``` now, with a config.model_dump(mode="json"), it works properly ``` external_providers_dir: ~/.llama/providers.d ``` Signed-off-by: Charlie Doern --- llama_stack/distribution/server/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index a7e860a36..f2e29a6f9 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -445,7 +445,7 @@ def main(args: argparse.Namespace | None = None): logger.info(log_line) logger.info("Run configuration:") - safe_config = redact_sensitive_fields(config.model_dump()) + safe_config = redact_sensitive_fields(config.model_dump(mode="json")) logger.info(yaml.dump(safe_config, indent=2)) app = FastAPI( From bbe0199bb70d56c3ec86d7066b3db0712afdd3f7 Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Thu, 10 Jul 2025 10:47:59 -0400 Subject: [PATCH 03/11] chore: update pre-commit hook versions (#2708) While investigating the `uv.lock` changes made in https://github.com/meta-llama/llama-stack/pull/2695 I noticed several of the pre-commit hook versions were out of date This PR updates them and fixes some new `ruff` errors --------- Signed-off-by: Nathan Weinberg --- .pre-commit-config.yaml | 10 +++++----- llama_stack/distribution/utils/context.py | 5 +---- .../providers/utils/telemetry/trace_protocol.py | 6 ++---- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ebbadefa6..f3a5a718b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,7 @@ repos: - id: check-toml - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.5.4 + rev: v1.5.5 hooks: - id: insert-license files: \.py$|\.sh$ @@ -38,7 +38,7 @@ 
repos: - docs/license_header.txt - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.4 + rev: v0.12.2 hooks: - id: ruff args: [ --fix ] @@ -46,14 +46,14 @@ repos: - id: ruff-format - repo: https://github.com/adamchainz/blacken-docs - rev: 1.19.0 + rev: 1.19.1 hooks: - id: blacken-docs additional_dependencies: - black==24.3.0 - repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.7.8 + rev: 0.7.20 hooks: - id: uv-lock - id: uv-export @@ -66,7 +66,7 @@ repos: ] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 + rev: v1.16.1 hooks: - id: mypy additional_dependencies: diff --git a/llama_stack/distribution/utils/context.py b/llama_stack/distribution/utils/context.py index 3fcd3315f..24b249890 100644 --- a/llama_stack/distribution/utils/context.py +++ b/llama_stack/distribution/utils/context.py @@ -6,12 +6,9 @@ from collections.abc import AsyncGenerator from contextvars import ContextVar -from typing import TypeVar - -T = TypeVar("T") -def preserve_contexts_async_generator( +def preserve_contexts_async_generator[T]( gen: AsyncGenerator[T, None], context_vars: list[ContextVar] ) -> AsyncGenerator[T, None]: """ diff --git a/llama_stack/providers/utils/telemetry/trace_protocol.py b/llama_stack/providers/utils/telemetry/trace_protocol.py index eb6d8b331..916f7622a 100644 --- a/llama_stack/providers/utils/telemetry/trace_protocol.py +++ b/llama_stack/providers/utils/telemetry/trace_protocol.py @@ -9,14 +9,12 @@ import inspect import json from collections.abc import AsyncGenerator, Callable from functools import wraps -from typing import Any, TypeVar +from typing import Any from pydantic import BaseModel from llama_stack.models.llama.datatypes import Primitive -T = TypeVar("T") - def serialize_value(value: Any) -> Primitive: return str(_prepare_for_json(value)) @@ -44,7 +42,7 @@ def _prepare_for_json(value: Any) -> str: return str(value) -def trace_protocol(cls: type[T]) -> type[T]: +def trace_protocol[T](cls: type[T]) -> type[T]: """ A class decorator that automatically traces all methods in a protocol/base class and its inheriting classes. From 83c6b200674b94d3e32a033398a79ba06380805e Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Thu, 10 Jul 2025 16:53:38 +0200 Subject: [PATCH 04/11] chore(api): add `mypy` coverage to `cli/stack` (#2650) # What does this PR do? 
This PR adds static type coverage to `llama-stack` Part of https://github.com/meta-llama/llama-stack/issues/2647 ## Test Plan Signed-off-by: Mustafa Elbehery --- llama_stack/cli/stack/_build.py | 24 +++++++++++++++++++----- pyproject.toml | 1 - 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 5d88b1d82..b573b2edc 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -93,7 +93,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) sys.exit(1) elif args.providers: - providers = dict() + providers_list: dict[str, str | list[str]] = dict() for api_provider in args.providers.split(","): if "=" not in api_provider: cprint( @@ -112,7 +112,15 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) sys.exit(1) if provider in providers_for_api: - providers.setdefault(api, []).append(provider) + if api not in providers_list: + providers_list[api] = [] + # Use type guarding to ensure we have a list + provider_value = providers_list[api] + if isinstance(provider_value, list): + provider_value.append(provider) + else: + # Convert string to list and append + providers_list[api] = [provider_value, provider] else: cprint( f"{provider} is not a valid provider for the {api} API.", @@ -121,7 +129,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) sys.exit(1) distribution_spec = DistributionSpec( - providers=providers, + providers=providers_list, description=",".join(args.providers), ) if not args.image_type: @@ -182,7 +190,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint("Tip: use to see options for the providers.\n", color="green", file=sys.stderr) - providers = dict() + providers: dict[str, str | list[str]] = dict() for api, providers_for_api in get_provider_registry().items(): available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")] if not available_providers: @@ -371,10 +379,16 @@ def _run_stack_build_command_from_build_config( if not image_name: raise ValueError("Please specify an image name when building a venv image") + # At this point, image_name should be guaranteed to be a string + if image_name is None: + raise ValueError("image_name should not be None after validation") + if template_name: build_dir = DISTRIBS_BASE_DIR / template_name build_file_path = build_dir / f"{template_name}-build.yaml" else: + if image_name is None: + raise ValueError("image_name cannot be None") build_dir = DISTRIBS_BASE_DIR / image_name build_file_path = build_dir / f"{image_name}-build.yaml" @@ -395,7 +409,7 @@ def _run_stack_build_command_from_build_config( build_file_path, image_name, template_or_config=template_name or config_path or str(build_file_path), - run_config=run_config_file, + run_config=run_config_file.as_posix() if run_config_file else None, ) if return_code != 0: raise RuntimeError(f"Failed to build image {image_name}") diff --git a/pyproject.toml b/pyproject.toml index 30598e5e3..d84a823a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -226,7 +226,6 @@ follow_imports = "silent" exclude = [ # As we fix more and more of these, we should remove them from the list "^llama_stack/cli/download\\.py$", - "^llama_stack/cli/stack/_build\\.py$", "^llama_stack/distribution/build\\.py$", "^llama_stack/distribution/client\\.py$", "^llama_stack/distribution/request_headers\\.py$", From b18f4d1ccfeb3321285a2cefb03baf515e978d7f Mon Sep 17 00:00:00 2001 From: Nathan Weinberg 
<31703736+nathan-weinberg@users.noreply.github.com> Date: Thu, 10 Jul 2025 11:24:10 -0400 Subject: [PATCH 05/11] ci: add config for pre-commit.ci (#2712) # What does this PR do? The project already had some config setup for https://pre-commit.ci/; this commit adds additional explicit fields. Closes #2711 **IMPORTANT:** A project maintainer must add `pre-commit.ci` to this repo for this to work - this can be done via https://pre-commit.ci/ Signed-off-by: Nathan Weinberg --- .pre-commit-config.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f3a5a718b..3c744c6bc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -133,3 +133,8 @@ repos: ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate + autofix_prs: true + autoupdate_branch: '' + autoupdate_schedule: weekly + skip: [] + submodules: false From 6a6b66ae4f965de4cd3cd71a4320e868fa777b95 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Thu, 10 Jul 2025 14:22:13 -0400 Subject: [PATCH 06/11] chore: Adding unit tests for OpenAI vector stores and migrating SQLite-vec registry to kvstore (#2665) # What does this PR do? This PR refactors the VectorIO backend logic for `sqlite-vec` and adds unit tests and fixtures to make it easy to test both `sqlite-vec` and `milvus`. Key changes: - `sqlite-vec` migrated to `kvstore` registry - added in-memory cache for sqlite-vec to be consistent with `milvus` - default fixtures moved to `conftest.py` - removed redundant tests from `sqlite-vec` - made `test_vector_io_openai_vector_stores.py` more easily extensible ## Test Plan Added unit tests covering the inline providers. --------- Signed-off-by: Francisco Javier Arceo --- .../providers/vector_io/inline_milvus.md | 2 +- .../providers/vector_io/inline_sqlite-vec.md | 6 +- .../providers/vector_io/inline_sqlite_vec.md | 6 +- .../inline/vector_io/milvus/config.py | 2 +- .../inline/vector_io/sqlite_vec/config.py | 14 +- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 313 +++++++----------- .../remote/vector_io/milvus/milvus.py | 8 + llama_stack/templates/open-benchmark/run.yaml | 3 + llama_stack/templates/starter/run.yaml | 3 + tests/unit/providers/vector_io/conftest.py | 157 +++++++++ .../providers/vector_io/test_sqlite_vec.py | 35 +- .../test_vector_io_openai_vector_stores.py | 297 ++++++----------- 12 files changed, 422 insertions(+), 424 deletions(-) diff --git a/docs/source/providers/vector_io/inline_milvus.md b/docs/source/providers/vector_io/inline_milvus.md index be7340c9d..3b3aad3fc 100644 --- a/docs/source/providers/vector_io/inline_milvus.md +++ b/docs/source/providers/vector_io/inline_milvus.md @@ -11,7 +11,7 @@ Please refer to the remote provider documentation.
| Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `db_path` | `` | No | PydanticUndefined | | -| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) | | `consistency_level` | `` | No | Strong | The consistency level of the Milvus server | ## Sample Configuration diff --git a/docs/source/providers/vector_io/inline_sqlite-vec.md b/docs/source/providers/vector_io/inline_sqlite-vec.md index fd3ec1dc4..ae7c45b21 100644 --- a/docs/source/providers/vector_io/inline_sqlite-vec.md +++ b/docs/source/providers/vector_io/inline_sqlite-vec.md @@ -205,12 +205,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `db_path` | `` | No | PydanticUndefined | | +| `db_path` | `` | No | PydanticUndefined | Path to the SQLite database file | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) | ## Sample Configuration ```yaml db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db ``` diff --git a/docs/source/providers/vector_io/inline_sqlite_vec.md b/docs/source/providers/vector_io/inline_sqlite_vec.md index e4b69c9ab..7e14bb8bd 100644 --- a/docs/source/providers/vector_io/inline_sqlite_vec.md +++ b/docs/source/providers/vector_io/inline_sqlite_vec.md @@ -10,12 +10,16 @@ Please refer to the sqlite-vec provider documentation. 
| Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `db_path` | `` | No | PydanticUndefined | | +| `db_path` | `` | No | PydanticUndefined | Path to the SQLite database file | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) | ## Sample Configuration ```yaml db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db ``` diff --git a/llama_stack/providers/inline/vector_io/milvus/config.py b/llama_stack/providers/inline/vector_io/milvus/config.py index a05ca1670..8cbd056be 100644 --- a/llama_stack/providers/inline/vector_io/milvus/config.py +++ b/llama_stack/providers/inline/vector_io/milvus/config.py @@ -18,7 +18,7 @@ from llama_stack.schema_utils import json_schema_type @json_schema_type class MilvusVectorIOConfig(BaseModel): db_path: str - kvstore: KVStoreConfig + kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)") consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong") @classmethod diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py index 4c57f4aba..525ed4b1f 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py @@ -6,14 +6,24 @@ from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) class SQLiteVectorIOConfig(BaseModel): - db_path: str + db_path: str = Field(description="Path to the SQLite database file") + kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)") @classmethod def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: return { "db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db", + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="sqlite_vec_registry.db", + ), } diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 7e977635a..6acd85c56 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -24,6 +24,8 @@ from llama_stack.apis.vector_io import ( VectorIO, ) from llama_stack.providers.datatypes import VectorDBsProtocolPrivate +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( RERANKER_TYPE_RRF, @@ -40,6 +42,13 @@ KEYWORD_SEARCH = "keyword" HYBRID_SEARCH = "hybrid" SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH} +VERSION = "v3" +VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::" +VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::" +OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_PREFIX = 
f"openai_vector_stores_files:sqlite_vec:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:sqlite_vec:{VERSION}::" + def serialize_vector(vector: list[float]) -> bytes: """Serialize a list of floats into a compact binary representation.""" @@ -117,13 +126,14 @@ class SQLiteVecIndex(EmbeddingIndex): - An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search. """ - def __init__(self, dimension: int, db_path: str, bank_id: str): + def __init__(self, dimension: int, db_path: str, bank_id: str, kvstore: KVStore | None = None): self.dimension = dimension self.db_path = db_path self.bank_id = bank_id self.metadata_table = f"chunks_{bank_id}".replace("-", "_") self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_") self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_") + self.kvstore = kvstore @classmethod async def create(cls, dimension: int, db_path: str, bank_id: str): @@ -425,27 +435,116 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc self.files_api = files_api self.cache: dict[str, VectorDBWithIndex] = {} self.openai_vector_stores: dict[str, dict[str, Any]] = {} + self.kvstore: KVStore | None = None async def initialize(self) -> None: - def _setup_connection(): - # Open a connection to the SQLite database (the file is specified in the config). + self.kvstore = await kvstore_impl(self.config.kvstore) + + start_key = VECTOR_DBS_PREFIX + end_key = f"{VECTOR_DBS_PREFIX}\xff" + stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) + for db_json in stored_vector_dbs: + vector_db = VectorDB.model_validate_json(db_json) + index = await SQLiteVecIndex.create( + vector_db.embedding_dimension, + self.config.db_path, + vector_db.identifier, + ) + self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) + + # load any existing OpenAI vector stores + self.openai_vector_stores = await self._load_openai_vector_stores() + + async def shutdown(self) -> None: + # nothing to do since we don't maintain a persistent connection + pass + + async def list_vector_dbs(self) -> list[VectorDB]: + return [v.vector_db for v in self.cache.values()] + + async def register_vector_db(self, vector_db: VectorDB) -> None: + index = await SQLiteVecIndex.create( + vector_db.embedding_dimension, + self.config.db_path, + vector_db.identifier, + ) + self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) + + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: + if vector_db_id in self.cache: + return self.cache[vector_db_id] + + if self.vector_db_store is None: + raise ValueError(f"Vector DB {vector_db_id} not found") + + vector_db = self.vector_db_store.get_vector_db(vector_db_id) + if not vector_db: + raise ValueError(f"Vector DB {vector_db_id} not found") + + index = VectorDBWithIndex( + vector_db=vector_db, + index=SQLiteVecIndex( + dimension=vector_db.embedding_dimension, + db_path=self.config.db_path, + bank_id=vector_db.identifier, + kvstore=self.kvstore, + ), + inference_api=self.inference_api, + ) + self.cache[vector_db_id] = index + return index + + async def unregister_vector_db(self, vector_db_id: str) -> None: + if vector_db_id not in self.cache: + logger.warning(f"Vector DB {vector_db_id} not found") + return + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + + # OpenAI Vector Store Mixin abstract method implementations + async def 
_save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Save vector store metadata to SQLite database.""" + assert self.kvstore is not None + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" + await self.kvstore.set(key=key, value=json.dumps(store_info)) + self.openai_vector_stores[store_id] = store_info + + async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: + """Load all vector store metadata from SQLite database.""" + assert self.kvstore is not None + start_key = OPENAI_VECTOR_STORES_PREFIX + end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff" + stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key) + stores = {} + for store_data in stored_openai_stores: + store_info = json.loads(store_data) + stores[store_info["id"]] = store_info + return stores + + async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Update vector store metadata in SQLite database.""" + assert self.kvstore is not None + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" + await self.kvstore.set(key=key, value=json.dumps(store_info)) + self.openai_vector_stores[store_id] = store_info + + async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: + """Delete vector store metadata from SQLite database.""" + assert self.kvstore is not None + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" + await self.kvstore.delete(key) + if store_id in self.openai_vector_stores: + del self.openai_vector_stores[store_id] + + async def _save_openai_vector_store_file( + self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] + ) -> None: + """Save vector store file metadata to SQLite database.""" + + def _create_or_store(): connection = _create_sqlite_connection(self.config.db_path) cur = connection.cursor() try: - # Create a table to persist vector DB registrations. - cur.execute(""" - CREATE TABLE IF NOT EXISTS vector_dbs ( - id TEXT PRIMARY KEY, - metadata TEXT - ); - """) - # Create a table to persist OpenAI vector stores. - cur.execute(""" - CREATE TABLE IF NOT EXISTS openai_vector_stores ( - id TEXT PRIMARY KEY, - metadata TEXT - ); - """) # Create a table to persist OpenAI vector store files. cur.execute(""" CREATE TABLE IF NOT EXISTS openai_vector_store_files ( @@ -464,168 +563,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc ); """) connection.commit() - # Load any existing vector DB registrations. 
- cur.execute("SELECT metadata FROM vector_dbs") - vector_db_rows = cur.fetchall() - return vector_db_rows - finally: - cur.close() - connection.close() - - vector_db_rows = await asyncio.to_thread(_setup_connection) - - # Load existing vector DBs - for row in vector_db_rows: - vector_db_data = row[0] - vector_db = VectorDB.model_validate_json(vector_db_data) - index = await SQLiteVecIndex.create( - vector_db.embedding_dimension, - self.config.db_path, - vector_db.identifier, - ) - self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) - - # Load existing OpenAI vector stores using the mixin method - self.openai_vector_stores = await self._load_openai_vector_stores() - - async def shutdown(self) -> None: - # nothing to do since we don't maintain a persistent connection - pass - - async def register_vector_db(self, vector_db: VectorDB) -> None: - def _register_db(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute( - "INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)", - (vector_db.identifier, vector_db.model_dump_json()), - ) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_register_db) - index = await SQLiteVecIndex.create( - vector_db.embedding_dimension, - self.config.db_path, - vector_db.identifier, - ) - self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) - - async def list_vector_dbs(self) -> list[VectorDB]: - return [v.vector_db for v in self.cache.values()] - - async def unregister_vector_db(self, vector_db_id: str) -> None: - if vector_db_id not in self.cache: - logger.warning(f"Vector DB {vector_db_id} not found") - return - await self.cache[vector_db_id].index.delete() - del self.cache[vector_db_id] - - def _delete_vector_db_from_registry(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,)) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_delete_vector_db_from_registry) - - # OpenAI Vector Store Mixin abstract method implementations - async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Save vector store metadata to SQLite database.""" - - def _store(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute( - "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)", - (store_id, json.dumps(store_info)), - ) - connection.commit() - except Exception as e: - logger.error(f"Error saving openai vector store {store_id}: {e}") - raise - finally: - cur.close() - connection.close() - - try: - await asyncio.to_thread(_store) - except Exception as e: - logger.error(f"Error saving openai vector store {store_id}: {e}") - raise - - async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: - """Load all vector store metadata from SQLite database.""" - - def _load(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute("SELECT metadata FROM openai_vector_stores") - rows = cur.fetchall() - return rows - finally: - cur.close() - connection.close() - - rows = await asyncio.to_thread(_load) - stores = {} - for row in rows: - store_data = row[0] - store_info = json.loads(store_data) - stores[store_info["id"]] = store_info - return stores - - async def 
_update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: - """Update vector store metadata in SQLite database.""" - - def _update(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute( - "UPDATE openai_vector_stores SET metadata = ? WHERE id = ?", - (json.dumps(store_info), store_id), - ) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_update) - - async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: - """Delete vector store metadata from SQLite database.""" - - def _delete(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,)) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_delete) - - async def _save_openai_vector_store_file( - self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] - ) -> None: - """Save vector store file metadata to SQLite database.""" - - def _store(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: cur.execute( "INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)", (store_id, file_id, json.dumps(file_info)), @@ -643,7 +580,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc connection.close() try: - await asyncio.to_thread(_store) + await asyncio.to_thread(_create_or_store) except Exception as e: logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}") raise @@ -722,6 +659,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc cur.execute( "DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id) ) + cur.execute( + "DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?", + (store_id, file_id), + ) connection.commit() finally: cur.close() @@ -730,15 +671,17 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc await asyncio.to_thread(_delete) async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: - if vector_db_id not in self.cache: - raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}") + index = await self._get_and_cache_vector_db_index(vector_db_id) + if not index: + raise ValueError(f"Vector DB {vector_db_id} not found") # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api # and then call our index's add_chunks. 
- await self.cache[vector_db_id].insert_chunks(chunks) + await index.insert_chunks(chunks) async def query_chunks( self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None ) -> QueryChunksResponse: - if vector_db_id not in self.cache: + index = await self._get_and_cache_vector_db_index(vector_db_id) + if not index: raise ValueError(f"Vector DB {vector_db_id} not found") - return await self.cache[vector_db_id].query_chunks(query, params) + return await index.query_chunks(query, params) diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 1f65e580e..a06130fd0 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -61,6 +61,11 @@ class MilvusIndex(EmbeddingIndex): self.consistency_level = consistency_level self.kvstore = kvstore + async def initialize(self): + # MilvusIndex does not require explicit initialization + # TODO: could move collection creation into initialization but it is not really necessary + pass + async def delete(self): if await asyncio.to_thread(self.client.has_collection, self.collection_name): await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) @@ -199,6 +204,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP if vector_db_id in self.cache: return self.cache[vector_db_id] + if self.vector_db_store is None: + raise ValueError(f"Vector DB {vector_db_id} not found") + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) if not vector_db: raise ValueError(f"Vector DB {vector_db_id} not found") diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 76c029864..0b368ebc9 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -39,6 +39,9 @@ providers: provider_type: inline::sqlite-vec config: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec_registry.db - provider_id: ${env.ENABLE_CHROMADB:+chromadb} provider_type: remote::chromadb config: diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index b3dfe32d5..888a2c3bf 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -144,6 +144,9 @@ providers: provider_type: inline::sqlite-vec config: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db - provider_id: ${env.ENABLE_MILVUS:=__disabled__} provider_type: inline::milvus config: diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 5eaca8a25..4a9639326 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -8,10 +8,18 @@ import random import numpy as np import pytest +from pymilvus import MilvusClient, connections +from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, ChunkMetadata +from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig +from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig +from 
llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter +from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter EMBEDDING_DIMENSION = 384 +COLLECTION_PREFIX = "test_collection" +MILVUS_ALIAS = "test_milvus" @pytest.fixture @@ -50,7 +58,156 @@ def sample_chunks(): return sample +@pytest.fixture(scope="session") +def sample_chunks_with_metadata(): + """Generates chunks that force multiple batches for a single document to expose ID conflicts.""" + n, k = 10, 3 + sample = [ + Chunk( + content=f"Sentence {i} from document {j}", + metadata={"document_id": f"document-{j}"}, + chunk_metadata=ChunkMetadata( + document_id=f"document-{j}", + chunk_id=f"document-{j}-chunk-{i}", + source=f"example source-{j}-{i}", + ), + ) + for j in range(k) + for i in range(n) + ] + return sample + + @pytest.fixture(scope="session") def sample_embeddings(sample_chunks): np.random.seed(42) return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks]) + + +@pytest.fixture(scope="session") +def sample_embeddings_with_metadata(sample_chunks_with_metadata): + np.random.seed(42) + return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata]) + + +@pytest.fixture(params=["milvus", "sqlite_vec"]) +def vector_provider(request): + return request.param + + +@pytest.fixture(scope="session") +def mock_inference_api(embedding_dimension): + class MockInferenceAPI: + async def embed_batch(self, texts: list[str]) -> list[list[float]]: + return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts] + + return MockInferenceAPI() + + +@pytest.fixture +async def unique_kvstore_config(tmp_path_factory): + # Generate a unique filename for this test + unique_id = f"test_kv_{np.random.randint(1e6)}" + temp_dir = tmp_path_factory.getbasetemp() + db_path = str(temp_dir / f"{unique_id}.db") + + return SqliteKVStoreConfig(db_path=db_path) + + +@pytest.fixture(scope="session") +def sqlite_vec_db_path(tmp_path_factory): + db_path = str(tmp_path_factory.getbasetemp() / "test.db") + return db_path + + +@pytest.fixture +async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory): + temp_dir = tmp_path_factory.getbasetemp() + db_path = str(temp_dir / f"test_sqlite_vec_{np.random.randint(1e6)}.db") + bank_id = f"sqlite_vec_bank_{np.random.randint(1e6)}" + index = SQLiteVecIndex(embedding_dimension, db_path, bank_id) + await index.initialize() + index.db_path = db_path + yield index + index.delete() + + +@pytest.fixture +async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension): + config = SQLiteVectorIOConfig( + db_path=sqlite_vec_db_path, + kvstore=SqliteKVStoreConfig(), + ) + adapter = SQLiteVecVectorIOAdapter( + config=config, + inference_api=mock_inference_api, + files_api=None, + ) + collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}" + await adapter.initialize() + await adapter.register_vector_db( + VectorDB( + identifier=collection_id, + provider_id="test_provider", + embedding_model="test_model", + embedding_dimension=embedding_dimension, + ) + ) + adapter.test_collection_id = collection_id + yield adapter + await adapter.shutdown() + + +@pytest.fixture(scope="session") +def milvus_vec_db_path(tmp_path_factory): + db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db") + return db_path + + +@pytest.fixture +async def milvus_vec_index(milvus_vec_db_path, embedding_dimension): 
+ client = MilvusClient(milvus_vec_db_path) + name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}" + connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path) + index = MilvusIndex(client, name, consistency_level="Strong") + index.db_path = milvus_vec_db_path + yield index + + +@pytest.fixture +async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api): + config = MilvusVectorIOConfig( + db_path=milvus_vec_db_path, + kvstore=SqliteKVStoreConfig(), + ) + adapter = MilvusVectorIOAdapter( + config=config, + inference_api=mock_inference_api, + files_api=None, + ) + await adapter.initialize() + await adapter.register_vector_db( + VectorDB( + identifier=adapter.metadata_collection_name, + provider_id="test_provider", + embedding_model="test_model", + embedding_dimension=128, + ) + ) + yield adapter + await adapter.shutdown() + + +@pytest.fixture +def vector_io_adapter(vector_provider, request): + """Returns the appropriate vector IO adapter based on the provider parameter.""" + if vector_provider == "milvus": + return request.getfixturevalue("milvus_vec_adapter") + else: + return request.getfixturevalue("sqlite_vec_adapter") + + +@pytest.fixture +def vector_index(vector_provider, request): + """Returns appropriate vector index based on provider parameter""" + return request.getfixturevalue(f"{vector_provider}_vec_index") diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py index 5d9d92cf3..8579c31bb 100644 --- a/tests/unit/providers/vector_io/test_sqlite_vec.py +++ b/tests/unit/providers/vector_io/test_sqlite_vec.py @@ -34,7 +34,7 @@ def loop(): return asyncio.new_event_loop() -@pytest_asyncio.fixture(scope="session", autouse=True) +@pytest_asyncio.fixture async def sqlite_vec_index(embedding_dimension, tmp_path_factory): temp_dir = tmp_path_factory.getbasetemp() db_path = str(temp_dir / "test_sqlite.db") @@ -44,38 +44,15 @@ async def sqlite_vec_index(embedding_dimension, tmp_path_factory): @pytest.mark.asyncio -async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings): - await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2) - connection = _create_sqlite_connection(sqlite_vec_index.db_path) - cur = connection.cursor() - cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}") - count = cur.fetchone()[0] - assert count == len(sample_chunks) - cur.close() - connection.close() - - -@pytest.mark.asyncio -async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension): - await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) - query_embedding = np.random.rand(embedding_dimension).astype(np.float32) - response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0) - assert isinstance(response, QueryChunksResponse) - assert len(response.chunks) == 2 - - -@pytest.mark.xfail(reason="Chunk Metadata not yet supported for SQLite-vec", strict=True) -async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks, sample_embeddings): - await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) - query_embedding = sample_embeddings[0] - response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0) - assert response.chunks[-1].chunk_metadata == sample_chunks[-1].chunk_metadata +async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata): + await 
sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata) + response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0) + assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata @pytest.mark.asyncio async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings): await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) - query_string = "Sentence 5" response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string) @@ -148,7 +125,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!" -@pytest_asyncio.fixture(scope="session") +@pytest.fixture(scope="session") async def sqlite_vec_adapter(sqlite_connection): config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None) diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index 0a109e833..5f7926ce6 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -4,253 +4,142 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import asyncio +import json import time from unittest.mock import AsyncMock import numpy as np import pytest -import pytest_asyncio -from pymilvus import Collection, MilvusClient, connections from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse -from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig -from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX, MilvusIndex, MilvusVectorIOAdapter -from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX -# TODO: Refactor these to be for inline vector-io providers -MILVUS_ALIAS = "test_milvus" -COLLECTION_PREFIX = "test_collection" - - -@pytest.fixture(scope="session") -def loop(): - return asyncio.new_event_loop() - - -@pytest.fixture(scope="session") -def mock_inference_api(embedding_dimension): - class MockInferenceAPI: - async def embed_batch(self, texts: list[str]) -> list[list[float]]: - return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts] - - return MockInferenceAPI() - - -@pytest_asyncio.fixture -async def unique_kvstore_config(tmp_path_factory): - # Generate a unique filename for this test - unique_id = f"test_kv_{np.random.randint(1e6)}" - temp_dir = tmp_path_factory.getbasetemp() - db_path = str(temp_dir / f"{unique_id}.db") - - return SqliteKVStoreConfig(db_path=db_path) - - -@pytest_asyncio.fixture(scope="session", autouse=True) -async def milvus_vec_index(embedding_dimension, tmp_path_factory): - temp_dir = tmp_path_factory.getbasetemp() - db_path = str(temp_dir / "test_milvus.db") - client = MilvusClient(db_path) - name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}" - connections.connect(alias=MILVUS_ALIAS, uri=db_path) - index = MilvusIndex(client, name, consistency_level="Strong") - index.db_path = db_path - yield 
index - - -@pytest_asyncio.fixture(scope="session") -async def milvus_vec_adapter(milvus_vec_index, mock_inference_api): - config = MilvusVectorIOConfig( - db_path=milvus_vec_index.db_path, - kvstore=SqliteKVStoreConfig(), - ) - adapter = MilvusVectorIOAdapter( - config=config, - inference_api=mock_inference_api, - files_api=None, - ) - await adapter.initialize() - await adapter.register_vector_db( - VectorDB( - identifier=adapter.metadata_collection_name, - provider_id="test_provider", - embedding_model="test_model", - embedding_dimension=128, - ) - ) - yield adapter - await adapter.shutdown() +# This test is a unit test for the inline VectoerIO providers. This should only contain +# tests which are specific to this class. More general (API-level) tests should be placed in +# tests/integration/vector_io/ +# +# How to run this test: +# +# pytest tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py \ +# -v -s --tb=short --disable-warnings --asyncio-mode=auto @pytest.mark.asyncio -async def test_cache_contains_initial_collection(milvus_vec_adapter): - coll_name = milvus_vec_adapter.metadata_collection_name - assert coll_name in milvus_vec_adapter.cache +async def test_initialize_index(vector_index): + await vector_index.initialize() @pytest.mark.asyncio -async def test_add_chunks(milvus_vec_index, sample_chunks, sample_embeddings): - await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings) - resp = await milvus_vec_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1) +async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings): + vector_index.delete() + vector_index.initialize() + await vector_index.add_chunks(sample_chunks, sample_embeddings) + resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1) assert resp.chunks[0].content == sample_chunks[0].content + vector_index.delete() @pytest.mark.asyncio -async def test_query_chunks_vector(milvus_vec_index, sample_chunks, sample_embeddings, embedding_dimension): - await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings) - query_emb = np.random.rand(embedding_dimension).astype(np.float32) - resp = await milvus_vec_index.query_vector(query_emb, k=2, score_threshold=0.0) - assert isinstance(resp, QueryChunksResponse) - assert len(resp.chunks) == 2 - - -@pytest.mark.asyncio -async def test_chunk_id_conflict(milvus_vec_index, sample_chunks, embedding_dimension): +async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension): embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32) - await milvus_vec_index.add_chunks(sample_chunks, embeddings) - coll = Collection(milvus_vec_index.collection_name, using=MILVUS_ALIAS) - ids = coll.query(expr="id >= 0", output_fields=["id"], timeout=30) - flat_ids = [i["id"] for i in ids] - assert len(flat_ids) == len(set(flat_ids)) + await vector_index.add_chunks(sample_chunks, embeddings) + resp = await vector_index.query_vector( + np.random.rand(embedding_dimension).astype(np.float32), + k=len(sample_chunks), + score_threshold=-1, + ) + + contents = [chunk.content for chunk in resp.chunks] + assert len(contents) == len(set(contents)) @pytest.mark.asyncio -async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_config): - kvstore = await kvstore_impl(unique_kvstore_config) - vector_db = VectorDB( - identifier="test_db", - provider_id="test_provider", - embedding_model="test_model", - embedding_dimension=128, - metadata={"test_key": 
"test_value"}, - ) - test_vector_db_data = vector_db.model_dump_json() - await kvstore.set(f"{VECTOR_DBS_PREFIX}test_db", test_vector_db_data) - tmp_milvus_vec_adapter = MilvusVectorIOAdapter( - config=MilvusVectorIOConfig( - db_path=milvus_vec_index.db_path, - kvstore=unique_kvstore_config, - ), - inference_api=None, - files_api=None, - ) - await tmp_milvus_vec_adapter.initialize() - - vector_db = VectorDB( - identifier="test_db", - provider_id="test_provider", - embedding_model="test_model", - embedding_dimension=128, - ) - test_vector_db_data = vector_db.model_dump_json() - await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data) - - assert milvus_vec_index.client is not None - assert isinstance(milvus_vec_index.client, MilvusClient) - assert tmp_milvus_vec_adapter.cache is not None - # registering a vector won't update the cache or openai_vector_store collection name - assert ( - tmp_milvus_vec_adapter.metadata_collection_name not in tmp_milvus_vec_adapter.cache - or tmp_milvus_vec_adapter.openai_vector_stores - ) - - -@pytest.mark.asyncio -async def test_persistence_across_adapter_restarts( - tmp_path, milvus_vec_index, mock_inference_api, unique_kvstore_config -): - adapter1 = MilvusVectorIOAdapter( - config=MilvusVectorIOConfig(db_path=milvus_vec_index.db_path, kvstore=unique_kvstore_config), - inference_api=mock_inference_api, - files_api=None, - ) - await adapter1.initialize() +async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter): + key = f"{VECTOR_DBS_PREFIX}db1" dummy = VectorDB( identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128 ) - await adapter1.register_vector_db(dummy) - await adapter1.shutdown() + await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump())) - await adapter1.initialize() - assert "foo_db" in adapter1.cache - await adapter1.shutdown() + await vector_io_adapter.initialize() @pytest.mark.asyncio -async def test_register_and_unregister_vector_db(milvus_vec_adapter): - try: - connections.disconnect(MILVUS_ALIAS) - except Exception as _: - pass +async def test_persistence_across_adapter_restarts(vector_io_adapter): + await vector_io_adapter.initialize() + dummy = VectorDB( + identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128 + ) + await vector_io_adapter.register_vector_db(dummy) + await vector_io_adapter.shutdown() - connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_adapter.config.db_path) + await vector_io_adapter.initialize() + assert "foo_db" in vector_io_adapter.cache + await vector_io_adapter.shutdown() + + +@pytest.mark.asyncio +async def test_register_and_unregister_vector_db(vector_io_adapter): unique_id = f"foo_db_{np.random.randint(1e6)}" dummy = VectorDB( identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128 ) - await milvus_vec_adapter.register_vector_db(dummy) - assert dummy.identifier in milvus_vec_adapter.cache - - if dummy.identifier in milvus_vec_adapter.cache: - index = milvus_vec_adapter.cache[dummy.identifier].index - if hasattr(index, "client") and hasattr(index.client, "_using"): - index.client._using = MILVUS_ALIAS - - await milvus_vec_adapter.unregister_vector_db(dummy.identifier) - assert dummy.identifier not in milvus_vec_adapter.cache + await vector_io_adapter.register_vector_db(dummy) + assert dummy.identifier in vector_io_adapter.cache + await 
vector_io_adapter.unregister_vector_db(dummy.identifier) + assert dummy.identifier not in vector_io_adapter.cache @pytest.mark.asyncio -async def test_query_unregistered_raises(milvus_vec_adapter): +async def test_query_unregistered_raises(vector_io_adapter): fake_emb = np.zeros(8, dtype=np.float32) - with pytest.raises(AttributeError): - await milvus_vec_adapter.query_chunks("no_such_db", fake_emb) + with pytest.raises(ValueError): + await vector_io_adapter.query_chunks("no_such_db", fake_emb) @pytest.mark.asyncio -async def test_insert_chunks_calls_underlying_index(milvus_vec_adapter): +async def test_insert_chunks_calls_underlying_index(vector_io_adapter): fake_index = AsyncMock() - milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index) + vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index) chunks = ["chunk1", "chunk2"] - await milvus_vec_adapter.insert_chunks("db1", chunks) + await vector_io_adapter.insert_chunks("db1", chunks) fake_index.insert_chunks.assert_awaited_once_with(chunks) @pytest.mark.asyncio -async def test_insert_chunks_missing_db_raises(milvus_vec_adapter): - milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) +async def test_insert_chunks_missing_db_raises(vector_io_adapter): + vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) with pytest.raises(ValueError): - await milvus_vec_adapter.insert_chunks("db_not_exist", []) + await vector_io_adapter.insert_chunks("db_not_exist", []) @pytest.mark.asyncio -async def test_query_chunks_calls_underlying_index_and_returns(milvus_vec_adapter): +async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter): expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1]) fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected)) - milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index) + vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index) - response = await milvus_vec_adapter.query_chunks("db1", "my_query", {"param": 1}) + response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1}) fake_index.query_chunks.assert_awaited_once_with("my_query", {"param": 1}) assert response is expected @pytest.mark.asyncio -async def test_query_chunks_missing_db_raises(milvus_vec_adapter): - milvus_vec_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) +async def test_query_chunks_missing_db_raises(vector_io_adapter): + vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None) with pytest.raises(ValueError): - await milvus_vec_adapter.query_chunks("db_missing", "q", None) + await vector_io_adapter.query_chunks("db_missing", "q", None) @pytest.mark.asyncio -async def test_save_openai_vector_store(milvus_vec_adapter): +async def test_save_openai_vector_store(vector_io_adapter): store_id = "vs_1234" openai_vector_store = { "id": store_id, @@ -260,14 +149,14 @@ async def test_save_openai_vector_store(milvus_vec_adapter): "embedding_model": "test_model", } - await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store) + await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store) - assert openai_vector_store["id"] in milvus_vec_adapter.openai_vector_stores - assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store + assert openai_vector_store["id"] in vector_io_adapter.openai_vector_stores + 
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store @pytest.mark.asyncio -async def test_update_openai_vector_store(milvus_vec_adapter): +async def test_update_openai_vector_store(vector_io_adapter): store_id = "vs_1234" openai_vector_store = { "id": store_id, @@ -277,14 +166,14 @@ async def test_update_openai_vector_store(milvus_vec_adapter): "embedding_model": "test_model", } - await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store) + await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store) openai_vector_store["description"] = "Updated description" - await milvus_vec_adapter._update_openai_vector_store(store_id, openai_vector_store) - assert milvus_vec_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store + await vector_io_adapter._update_openai_vector_store(store_id, openai_vector_store) + assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store @pytest.mark.asyncio -async def test_delete_openai_vector_store(milvus_vec_adapter): +async def test_delete_openai_vector_store(vector_io_adapter): store_id = "vs_1234" openai_vector_store = { "id": store_id, @@ -294,13 +183,13 @@ async def test_delete_openai_vector_store(milvus_vec_adapter): "embedding_model": "test_model", } - await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store) - await milvus_vec_adapter._delete_openai_vector_store_from_storage(store_id) - assert openai_vector_store["id"] not in milvus_vec_adapter.openai_vector_stores + await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store) + await vector_io_adapter._delete_openai_vector_store_from_storage(store_id) + assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores @pytest.mark.asyncio -async def test_load_openai_vector_stores(milvus_vec_adapter): +async def test_load_openai_vector_stores(vector_io_adapter): store_id = "vs_1234" openai_vector_store = { "id": store_id, @@ -310,13 +199,13 @@ async def test_load_openai_vector_stores(milvus_vec_adapter): "embedding_model": "test_model", } - await milvus_vec_adapter._save_openai_vector_store(store_id, openai_vector_store) - loaded_stores = await milvus_vec_adapter._load_openai_vector_stores() + await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store) + loaded_stores = await vector_io_adapter._load_openai_vector_stores() assert loaded_stores[store_id] == openai_vector_store @pytest.mark.asyncio -async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory): +async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory): store_id = "vs_1234" file_id = "file_1234" @@ -334,11 +223,11 @@ async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factor ] # validating we don't raise an exception - await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) + await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) @pytest.mark.asyncio -async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory): +async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory): store_id = "vs_1234" file_id = "file_1234" @@ -355,24 +244,24 @@ async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_fact {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}} 
] - await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) + await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) updated_file_info = file_info.copy() updated_file_info["filename"] = "updated_test_file.txt" - await milvus_vec_adapter._update_openai_vector_store_file( + await vector_io_adapter._update_openai_vector_store_file( store_id, file_id, updated_file_info, ) - loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file(store_id, file_id) + loaded_contents = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id) assert loaded_contents == updated_file_info assert loaded_contents != file_info @pytest.mark.asyncio -async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_path_factory): +async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory): store_id = "vs_1234" file_id = "file_1234" @@ -389,14 +278,14 @@ async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_pa {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}} ] - await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) + await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) - loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id) + loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id) assert loaded_contents == file_contents @pytest.mark.asyncio -async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, tmp_path_factory): +async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory): store_id = "vs_1234" file_id = "file_1234" @@ -413,8 +302,8 @@ async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}} ] - await milvus_vec_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) - await milvus_vec_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id) + await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents) + await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id) - loaded_contents = await milvus_vec_adapter._load_openai_vector_store_file_contents(store_id, file_id) + loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id) assert loaded_contents == [] From 0bbff91c7ef2614c98e1af3578fb10980c0cfd36 Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Thu, 10 Jul 2025 14:47:54 -0400 Subject: [PATCH 07/11] docs: fix a few broken things in the CONTRIBUTING.md (#2714) # What does this PR do? "dev" dependencies were moved in pyproject.toml typo with guidance around automatic doc generation Signed-off-by: Nathan Weinberg --- CONTRIBUTING.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b9b25cedf..304c4dd26 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -66,7 +66,7 @@ You can install the dependencies by running: ```bash cd llama-stack -uv sync --extra dev +uv sync --group dev uv pip install -e . 
source .venv/bin/activate ``` @@ -168,7 +168,7 @@ manually as they are auto-generated. ### Updating the provider documentation -If you have made changes to a provider's configuration, you should run `./scripts/distro_codegen.py` +If you have made changes to a provider's configuration, you should run `./scripts/provider_codegen.py` to re-generate the documentation. You should not change `docs/source/.../providers/` files manually as they are auto-generated. Note that the provider "description" field will be used to generate the provider documentation. From 9f04bc6d1af4bb70512fc1f09c911c8c0a060401 Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Thu, 10 Jul 2025 16:14:10 -0400 Subject: [PATCH 08/11] chore: move "install.sh" script into "scripts" dir (#2719) # What does this PR do? "install.sh" is something that a general user might not use e.g. it is specific to using the "ollama" inference provider cleanup the top-level structure of the repo by moving it into the "scripts" dir and updating the relevant references accordingly Signed-off-by: Nathan Weinberg --- .github/workflows/install-script-ci.yml | 8 ++++---- README.md | 2 +- install.sh => scripts/install.sh | 0 3 files changed, 5 insertions(+), 5 deletions(-) rename install.sh => scripts/install.sh (100%) diff --git a/.github/workflows/install-script-ci.yml b/.github/workflows/install-script-ci.yml index 2eb234c77..d711444e8 100644 --- a/.github/workflows/install-script-ci.yml +++ b/.github/workflows/install-script-ci.yml @@ -3,10 +3,10 @@ name: Installer CI on: pull_request: paths: - - 'install.sh' + - 'scripts/install.sh' push: paths: - - 'install.sh' + - 'scripts/install.sh' schedule: - cron: '0 2 * * *' # every day at 02:00 UTC @@ -16,11 +16,11 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2 - name: Run ShellCheck on install.sh - run: shellcheck install.sh + run: shellcheck scripts/install.sh smoke-test: needs: lint runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2 - name: Run installer end-to-end - run: ./install.sh + run: ./scripts/install.sh diff --git a/README.md b/README.md index 1bebf6b19..9148ce05d 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ As more providers start supporting Llama 4, you can use them in Llama Stack as w To try Llama Stack locally, run: ```bash -curl -LsSf https://github.com/meta-llama/llama-stack/raw/main/install.sh | bash +curl -LsSf https://github.com/meta-llama/llama-stack/raw/main/scripts/install.sh | bash ``` ### Overview diff --git a/install.sh b/scripts/install.sh similarity index 100% rename from install.sh rename to scripts/install.sh From 5fe3027cbfd823f420b493bf07537be5d8d65436 Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Thu, 10 Jul 2025 17:06:10 -0400 Subject: [PATCH 09/11] chore: remove "rfc" directory and move original rfc to "docs" (#2718) # What does this PR do? the "rfc" directory has only a single document in it, and its the original RFC for creating Llama Stack simply the project directory structure by moving this into the "docs" directory and renaming it to "original_rfc" to preserve the context of the doc ## Why did you do this? 
A simplified top-level directory structure helps keep the project simpler and prevents misleading new contributors into thinking we use it (we really don't) --------- Signed-off-by: Nathan Weinberg Co-authored-by: raghotham --- rfcs/RFC-0001-llama-stack.md => docs/original_rfc.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) rename rfcs/RFC-0001-llama-stack.md => docs/original_rfc.md (96%) diff --git a/rfcs/RFC-0001-llama-stack.md b/docs/original_rfc.md similarity index 96% rename from rfcs/RFC-0001-llama-stack.md rename to docs/original_rfc.md index 222487bd6..dc95a04cb 100644 --- a/rfcs/RFC-0001-llama-stack.md +++ b/docs/original_rfc.md @@ -1,5 +1,7 @@ # The Llama Stack API +*Originally authored Jul 23, 2024* + **Authors:** * Meta: @raghotham, @ashwinb, @hjshah, @jspisak @@ -24,7 +26,7 @@ Meta releases weights of both the pretrained and instruction fine-tuned Llama mo ### Model Lifecycle -![Figure 1: Model Life Cycle](../docs/resources/model-lifecycle.png) +![Figure 1: Model Life Cycle](resources/model-lifecycle.png) For each of the operations that need to be performed (e.g. fine tuning, inference, evals etc) during the model life cycle, we identified the capabilities as toolchain APIs that are needed. Some of these capabilities are primitive operations like inference while other capabilities like synthetic data generation are composed of other capabilities. The list of APIs we have identified to support the lifecycle of Llama models is below: @@ -37,7 +39,7 @@ For each of the operations that need to be performed (e.g. fine tuning, inferenc ### Agentic System -![Figure 2: Agentic System](../docs/resources/agentic-system.png) +![Figure 2: Agentic System](resources/agentic-system.png) In addition to the model lifecycle, we considered the different components involved in an agentic system. Specifically around tool calling and shields. Since the model may decide to call tools, a single model inference call is not enough. What’s needed is an agentic loop consisting of tool calls and inference. The model provides separate tokens representing end-of-message and end-of-turn. A message represents a possible stopping point for execution where the model can inform the execution environment that a tool call needs to be made. The execution environment, upon execution, adds back the result to the context window and makes another inference call. This process can get repeated until an end-of-turn token is generated. Note that as of today, in the OSS world, such a “loop” is often coded explicitly via elaborate prompt engineering using a ReAct pattern (typically) or preconstructed execution graph. Llama 3.1 (and future Llamas) attempts to absorb this multi-step reasoning loop inside the main model itself. @@ -63,9 +65,9 @@ The sequence diagram that details the steps is [here](https://github.com/meta-ll We define the Llama Stack as a layer cake shown below. -![Figure 3: Llama Stack](../docs/resources/llama-stack.png) +![Figure 3: Llama Stack](resources/llama-stack.png) -The API is defined in the [YAML](../docs/_static/llama-stack-spec.yaml) and [HTML](../docs/_static/llama-stack-spec.html) files. +The API is defined in the [YAML](_static/llama-stack-spec.yaml) and [HTML](_static/llama-stack-spec.html) files. ## Sample implementations From 4cf1952c32b74be607fd1aefb026a65e70d8b7ef Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 10 Jul 2025 14:40:17 -0700 Subject: [PATCH 10/11] chore: update vllm k8s command to support tool calling (#2717) # What does this PR do? 
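Updates the Kubernetes example so the vLLM server starts with tool calling enabled: the `vllm serve` command in `docs/source/distributions/k8s/vllm-k8s.yaml.template` now passes `--enable-auto-tool-choice --tool-call-parser llama4_pythonic`, and `apply.sh` pins the Postgres credentials and the inference/safety model names rather than reading overrides from the environment (see the diff below).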
## Test Plan --- docs/source/distributions/k8s/apply.sh | 10 +++++----- docs/source/distributions/k8s/vllm-k8s.yaml.template | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/distributions/k8s/apply.sh b/docs/source/distributions/k8s/apply.sh index 06b1ea10c..7b403d34e 100755 --- a/docs/source/distributions/k8s/apply.sh +++ b/docs/source/distributions/k8s/apply.sh @@ -6,12 +6,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -export POSTGRES_USER=${POSTGRES_USER:-llamastack} -export POSTGRES_DB=${POSTGRES_DB:-llamastack} -export POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-llamastack} +export POSTGRES_USER=llamastack +export POSTGRES_DB=llamastack +export POSTGRES_PASSWORD=llamastack -export INFERENCE_MODEL=${INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct} -export SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B} +export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct +export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B # HF_TOKEN should be set by the user; base64 encode it for the secret if [ -n "${HF_TOKEN:-}" ]; then diff --git a/docs/source/distributions/k8s/vllm-k8s.yaml.template b/docs/source/distributions/k8s/vllm-k8s.yaml.template index 03f3759c3..22bee4bbc 100644 --- a/docs/source/distributions/k8s/vllm-k8s.yaml.template +++ b/docs/source/distributions/k8s/vllm-k8s.yaml.template @@ -32,7 +32,7 @@ spec: image: vllm/vllm-openai:latest command: ["/bin/sh", "-c"] args: - - "vllm serve ${INFERENCE_MODEL} --dtype float16 --enforce-eager --max-model-len 4096 --gpu-memory-utilization 0.6" + - "vllm serve ${INFERENCE_MODEL} --dtype float16 --enforce-eager --max-model-len 4096 --gpu-memory-utilization 0.6 --enable-auto-tool-choice --tool-call-parser llama4_pythonic" env: - name: INFERENCE_MODEL value: "${INFERENCE_MODEL}" From d880c2df0ed0d1405a5458a25309ad3b66907219 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 10 Jul 2025 14:40:32 -0700 Subject: [PATCH 11/11] fix: auth sql store: user is owner policy (#2674) # What does this PR do? The current authorized sql store implementation does not respect user.principal (only checks attributes). This PR addresses that. ## Test Plan Added test cases to integration tests. 
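
For reviewers, a minimal sketch of the new behavior, adapted from the integration test added below. The `documents` table, the row data, and the db path are illustrative only and are not part of this change:

```python
import asyncio

from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, sqlstore_impl


async def demo() -> None:
    # Wrap any base SQL store; create_table() adds the owner_principal and
    # access_attributes columns used for row-level access control.
    store = AuthorizedSqlStore(sqlstore_impl(SqliteSqlStoreConfig(db_path="/tmp/owner_demo.db")))
    await store.create_table("documents", {"id": ColumnType.STRING, "data": ColumnType.STRING})

    # insert() stamps each row with the authenticated caller's principal
    # (None when there is no authenticated user).
    await store.insert("documents", {"id": "1", "data": "example"})

    # A policy that only permits reads when the requesting user owns the record.
    owner_only = [AccessRule(permit=Scope(actions=[Action.READ]), when=["user is owner"])]

    # fetch_all() filters out rows whose owner_principal does not match the caller.
    # With no authenticated request context this prints [], mirroring the
    # anonymous-user case in the test below.
    result = await store.fetch_all("documents", policy=owner_only)
    print([row["id"] for row in result.data])


if __name__ == "__main__":
    asyncio.run(demo())
```

The integration test below exercises the same policy with two authenticated users and an anonymous caller.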
--- .../access_control/access_control.py | 2 +- .../utils/sqlstore/authorized_sqlstore.py | 54 ++-- .../utils/sqlstore/sqlalchemy_sqlstore.py | 58 ++-- .../sqlstore/test_authorized_sqlstore.py | 304 +++++++++++------- tests/unit/utils/test_authorized_sqlstore.py | 4 +- 5 files changed, 247 insertions(+), 175 deletions(-) diff --git a/llama_stack/distribution/access_control/access_control.py b/llama_stack/distribution/access_control/access_control.py index 075152ce4..64c0122c1 100644 --- a/llama_stack/distribution/access_control/access_control.py +++ b/llama_stack/distribution/access_control/access_control.py @@ -81,7 +81,7 @@ def is_action_allowed( if not len(policy): policy = default_policy() - qualified_resource_id = resource.type + "::" + resource.identifier + qualified_resource_id = f"{resource.type}::{resource.identifier}" for rule in policy: if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal): if rule.when: diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index 5dff7f122..864a7dbb6 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -39,22 +39,10 @@ SQL_OPTIMIZED_POLICY = [ class SqlRecord(ProtectedResource): - """Simple ProtectedResource implementation for SQL records.""" - - def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None): + def __init__(self, record_id: str, table_name: str, owner: User): self.type = f"sql_record::{table_name}" self.identifier = record_id - - if access_attributes: - self.owner = User( - principal="system", - attributes=access_attributes, - ) - else: - self.owner = User( - principal="system_public", - attributes=None, - ) + self.owner = owner class AuthorizedSqlStore: @@ -101,22 +89,27 @@ class AuthorizedSqlStore: async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: """Create a table with built-in access control support.""" - await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON) enhanced_schema = dict(schema) if "access_attributes" not in enhanced_schema: enhanced_schema["access_attributes"] = ColumnType.JSON + if "owner_principal" not in enhanced_schema: + enhanced_schema["owner_principal"] = ColumnType.STRING await self.sql_store.create_table(table, enhanced_schema) + await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON) + await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING) async def insert(self, table: str, data: Mapping[str, Any]) -> None: """Insert a row with automatic access control attribute capture.""" enhanced_data = dict(data) current_user = get_authenticated_user() - if current_user and current_user.attributes: + if current_user: + enhanced_data["owner_principal"] = current_user.principal enhanced_data["access_attributes"] = current_user.attributes else: + enhanced_data["owner_principal"] = None enhanced_data["access_attributes"] = None await self.sql_store.insert(table, enhanced_data) @@ -146,9 +139,12 @@ class AuthorizedSqlStore: for row in rows.data: stored_access_attrs = row.get("access_attributes") + stored_owner_principal = row.get("owner_principal") or "" record_id = row.get("id", "unknown") - sql_record = SqlRecord(str(record_id), table, stored_access_attrs) + sql_record = SqlRecord( + str(record_id), table, 
User(principal=stored_owner_principal, attributes=stored_access_attrs) + ) if is_action_allowed(policy, Action.READ, sql_record, current_user): filtered_rows.append(row) @@ -186,8 +182,10 @@ class AuthorizedSqlStore: Only applies SQL filtering for the default policy to ensure correctness. For custom policies, uses conservative filtering to avoid blocking legitimate access. """ + current_user = get_authenticated_user() + if not policy or policy == SQL_OPTIMIZED_POLICY: - return self._build_default_policy_where_clause() + return self._build_default_policy_where_clause(current_user) else: return self._build_conservative_where_clause() @@ -227,29 +225,27 @@ class AuthorizedSqlStore: def _get_public_access_conditions(self) -> list[str]: """Get the SQL conditions for public access.""" + # Public records are records that have no owner_principal or access_attributes + conditions = ["owner_principal = ''"] if self.database_type == SqlStoreType.postgres: # Postgres stores JSON null as 'null' - return ["access_attributes::text = 'null'"] + conditions.append("access_attributes::text = 'null'") elif self.database_type == SqlStoreType.sqlite: - return ["access_attributes = 'null'"] + conditions.append("access_attributes = 'null'") else: raise ValueError(f"Unsupported database type: {self.database_type}") + return conditions - def _build_default_policy_where_clause(self) -> str: + def _build_default_policy_where_clause(self, current_user: User | None) -> str: """Build SQL WHERE clause for the default policy. Default policy: permit all actions when user in owners [roles, teams, projects, namespaces] This means user must match ALL attribute categories that exist in the resource. """ - current_user = get_authenticated_user() - base_conditions = self._get_public_access_conditions() - if not current_user or not current_user.attributes: - # Only allow public records - return f"({' OR '.join(base_conditions)})" - else: - user_attr_conditions = [] + user_attr_conditions = [] + if current_user and current_user.attributes: for attr_key, user_values in current_user.attributes.items(): if user_values: value_conditions = [] @@ -269,7 +265,7 @@ class AuthorizedSqlStore: all_requirements_met = f"({' AND '.join(user_attr_conditions)})" base_conditions.append(all_requirements_met) - return f"({' OR '.join(base_conditions)})" + return f"({' OR '.join(base_conditions)})" def _build_conservative_where_clause(self) -> str: """Conservative SQL filtering for custom policies. 
diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 3aecb0d59..6414929db 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -244,35 +244,41 @@ class SqlAlchemySqlStoreImpl(SqlStore): engine = create_async_engine(self.config.engine_str) try: - inspector = inspect(engine) - - table_names = inspector.get_table_names() - if table not in table_names: - return - - existing_columns = inspector.get_columns(table) - column_names = [col["name"] for col in existing_columns] - - if column_name in column_names: - return - - sqlalchemy_type = TYPE_MAPPING.get(column_type) - if not sqlalchemy_type: - raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.") - - # Create the ALTER TABLE statement - # Note: We need to get the dialect-specific type name - dialect = engine.dialect - type_impl = sqlalchemy_type() - compiled_type = type_impl.compile(dialect=dialect) - - nullable_clause = "" if nullable else " NOT NULL" - add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}") - async with engine.begin() as conn: + + def check_column_exists(sync_conn): + inspector = inspect(sync_conn) + + table_names = inspector.get_table_names() + if table not in table_names: + return False, False # table doesn't exist, column doesn't exist + + existing_columns = inspector.get_columns(table) + column_names = [col["name"] for col in existing_columns] + + return True, column_name in column_names # table exists, column exists or not + + table_exists, column_exists = await conn.run_sync(check_column_exists) + if not table_exists or column_exists: + return + + sqlalchemy_type = TYPE_MAPPING.get(column_type) + if not sqlalchemy_type: + raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.") + + # Create the ALTER TABLE statement + # Note: We need to get the dialect-specific type name + dialect = engine.dialect + type_impl = sqlalchemy_type() + compiled_type = type_impl.compile(dialect=dialect) + + nullable_clause = "" if nullable else " NOT NULL" + add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}") + await conn.execute(add_column_sql) - except Exception: + except Exception as e: # If any error occurs during migration, log it but don't fail # The table creation will handle adding the column + logger.error(f"Error adding column {column_name} to table {table}: {e}") pass diff --git a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py index 93b4d8905..bf6077532 100644 --- a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py +++ b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py @@ -14,8 +14,7 @@ from llama_stack.distribution.access_control.access_control import default_polic from llama_stack.distribution.datatypes import User from llama_stack.providers.utils.sqlstore.api import ColumnType from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore -from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl -from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, 
SqliteSqlStoreConfig, sqlstore_impl def get_postgres_config(): @@ -30,144 +29,213 @@ def get_postgres_config(): def get_sqlite_config(): - """Get SQLite configuration with temporary database.""" - tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False) - tmp_file.close() - return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name + """Get SQLite configuration with temporary file database.""" + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + temp_file.close() + return SqliteSqlStoreConfig(db_path=temp_file.name) + + +# Backend configurations for parametrized tests +BACKEND_CONFIGS = [ + pytest.param( + get_postgres_config, + marks=pytest.mark.skipif( + not os.environ.get("ENABLE_POSTGRES_TESTS"), + reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable", + ), + id="postgres", + ), + pytest.param(get_sqlite_config, id="sqlite"), +] + + +@pytest.fixture +def authorized_store(backend_config): + """Set up authorized store with proper cleanup.""" + config_func = backend_config + + config = config_func() + + base_sqlstore = sqlstore_impl(config) + authorized_store = AuthorizedSqlStore(base_sqlstore) + + yield authorized_store + + if hasattr(config, "db_path"): + try: + os.unlink(config.db_path) + except (OSError, FileNotFoundError): + pass + + +async def create_test_table(authorized_store, table_name): + """Create a test table with standard schema.""" + await authorized_store.create_table( + table=table_name, + schema={ + "id": ColumnType.STRING, + "data": ColumnType.STRING, + }, + ) + + +async def cleanup_records(sql_store, table_name, record_ids): + """Clean up test records.""" + for record_id in record_ids: + try: + await sql_store.delete(table_name, {"id": record_id}) + except Exception: + pass @pytest.mark.asyncio -@pytest.mark.parametrize( - "backend_config", - [ - pytest.param( - ("postgres", get_postgres_config), - marks=pytest.mark.skipif( - not os.environ.get("ENABLE_POSTGRES_TESTS"), - reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable", - ), - id="postgres", - ), - pytest.param(("sqlite", get_sqlite_config), id="sqlite"), - ], -) +@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS) @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") -async def test_json_comparison(mock_get_authenticated_user, backend_config): +async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request): """Test that JSON column comparisons work correctly for both PostgreSQL and SQLite""" - backend_name, config_func = backend_config + backend_name = request.node.callspec.id - # Handle different config types - if backend_name == "postgres": - config = config_func() - cleanup_path = None - else: # sqlite - config, cleanup_path = config_func() + # Create test table + table_name = f"test_json_comparison_{backend_name}" + await create_test_table(authorized_store, table_name) try: - base_sqlstore = SqlAlchemySqlStoreImpl(config) - authorized_store = AuthorizedSqlStore(base_sqlstore) + # Test with no authenticated user (should handle JSON null comparison) + mock_get_authenticated_user.return_value = None - # Create test table - table_name = f"test_json_comparison_{backend_name}" - await authorized_store.create_table( - table=table_name, - schema={ - "id": ColumnType.STRING, - "data": ColumnType.STRING, - }, + # Insert some test data + await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) + + # Test fetching with no user - should not 
error on JSON comparison + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + assert len(result.data) == 1 + assert result.data[0]["id"] == "1" + assert result.data[0]["access_attributes"] is None + + # Test with authenticated user + test_user = User("test-user", {"roles": ["admin"]}) + mock_get_authenticated_user.return_value = test_user + + # Insert data with user attributes + await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) + + # Fetch all - admin should see both + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + assert len(result.data) == 2 + + # Test with non-admin user + regular_user = User("regular-user", {"roles": ["user"]}) + mock_get_authenticated_user.return_value = regular_user + + # Should only see public record + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + assert len(result.data) == 1 + assert result.data[0]["id"] == "1" + + # Test the category missing branch: user with multiple attributes + multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]}) + mock_get_authenticated_user.return_value = multi_user + + # Insert record with multi-user (has both roles and teams) + await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"}) + + # Test different user types to create records with different attribute patterns + # Record with only roles (teams category will be missing) + roles_only_user = User("roles-user", {"roles": ["admin"]}) + mock_get_authenticated_user.return_value = roles_only_user + await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"}) + + # Record with only teams (roles category will be missing) + teams_only_user = User("teams-user", {"teams": ["dev"]}) + mock_get_authenticated_user.return_value = teams_only_user + await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"}) + + # Record with different roles/teams (shouldn't match our test user) + different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]}) + mock_get_authenticated_user.return_value = different_user + await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"}) + + # Now test with the multi-user who has both roles=admin and teams=dev + mock_get_authenticated_user.return_value = multi_user + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + + # Should see: + # - public record (1) - no access_attributes + # - admin record (2) - user matches roles=admin, teams missing (allowed) + # - multi_user record (3) - user matches both roles=admin and teams=dev + # - roles_only record (4) - user matches roles=admin, teams missing (allowed) + # - teams_only record (5) - user matches teams=dev, roles missing (allowed) + # Should NOT see: + # - different_user record (6) - user doesn't match roles=user or teams=qa + expected_ids = {"1", "2", "3", "4", "5"} + actual_ids = {record["id"] for record in result.data} + assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}" + + # Verify the category missing logic specifically + # Records 4 and 5 test the "category missing" branch where one attribute category is missing + category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]} + assert category_test_ids == {"4", "5"}, ( + f"Category missing logic failed: expected 4,5 but got {category_test_ids}" ) - try: - # Test with no authenticated user (should handle JSON null 
comparison) - mock_get_authenticated_user.return_value = None + finally: + # Clean up records + await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"]) - # Insert some test data - await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) - # Test fetching with no user - should not error on JSON comparison - result = await authorized_store.fetch_all(table_name, policy=default_policy()) - assert len(result.data) == 1 - assert result.data[0]["id"] == "1" - assert result.data[0]["access_attributes"] is None +@pytest.mark.asyncio +@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS) +@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request): + """Test that 'user is owner' policies work correctly with record ownership""" + from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope - # Test with authenticated user - test_user = User("test-user", {"roles": ["admin"]}) - mock_get_authenticated_user.return_value = test_user + backend_name = request.node.callspec.id - # Insert data with user attributes - await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) + # Create test table + table_name = f"test_ownership_{backend_name}" + await create_test_table(authorized_store, table_name) - # Fetch all - admin should see both - result = await authorized_store.fetch_all(table_name, policy=default_policy()) - assert len(result.data) == 2 + try: + # Test with first user who creates records + user1 = User("user1", {"roles": ["admin"]}) + mock_get_authenticated_user.return_value = user1 - # Test with non-admin user - regular_user = User("regular-user", {"roles": ["user"]}) - mock_get_authenticated_user.return_value = regular_user + # Insert a record owned by user1 + await authorized_store.insert(table_name, {"id": "1", "data": "user1_data"}) - # Should only see public record - result = await authorized_store.fetch_all(table_name, policy=default_policy()) - assert len(result.data) == 1 - assert result.data[0]["id"] == "1" + # Test with second user + user2 = User("user2", {"roles": ["user"]}) + mock_get_authenticated_user.return_value = user2 - # Test the category missing branch: user with multiple attributes - multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]}) - mock_get_authenticated_user.return_value = multi_user + # Insert a record owned by user2 + await authorized_store.insert(table_name, {"id": "2", "data": "user2_data"}) - # Insert record with multi-user (has both roles and teams) - await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"}) + # Create a policy that only allows access when user is the owner + owner_only_policy = [ + AccessRule( + permit=Scope(actions=[Action.READ]), + when=["user is owner"], + ), + ] - # Test different user types to create records with different attribute patterns - # Record with only roles (teams category will be missing) - roles_only_user = User("roles-user", {"roles": ["admin"]}) - mock_get_authenticated_user.return_value = roles_only_user - await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"}) + # Test user1 access - should only see their own record + mock_get_authenticated_user.return_value = user1 + result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + assert len(result.data) == 1, f"Expected user1 to see 1 record, got 
{len(result.data)}" + assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}" - # Record with only teams (roles category will be missing) - teams_only_user = User("teams-user", {"teams": ["dev"]}) - mock_get_authenticated_user.return_value = teams_only_user - await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"}) + # Test user2 access - should only see their own record + mock_get_authenticated_user.return_value = user2 + result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}" + assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}" - # Record with different roles/teams (shouldn't match our test user) - different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]}) - mock_get_authenticated_user.return_value = different_user - await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"}) - - # Now test with the multi-user who has both roles=admin and teams=dev - mock_get_authenticated_user.return_value = multi_user - result = await authorized_store.fetch_all(table_name, policy=default_policy()) - - # Should see: - # - public record (1) - no access_attributes - # - admin record (2) - user matches roles=admin, teams missing (allowed) - # - multi_user record (3) - user matches both roles=admin and teams=dev - # - roles_only record (4) - user matches roles=admin, teams missing (allowed) - # - teams_only record (5) - user matches teams=dev, roles missing (allowed) - # Should NOT see: - # - different_user record (6) - user doesn't match roles=user or teams=qa - expected_ids = {"1", "2", "3", "4", "5"} - actual_ids = {record["id"] for record in result.data} - assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}" - - # Verify the category missing logic specifically - # Records 4 and 5 test the "category missing" branch where one attribute category is missing - category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]} - assert category_test_ids == {"4", "5"}, ( - f"Category missing logic failed: expected 4,5 but got {category_test_ids}" - ) - - finally: - # Clean up records - for record_id in ["1", "2", "3", "4", "5", "6"]: - try: - await base_sqlstore.delete(table_name, {"id": record_id}) - except Exception: - pass + # Test with anonymous user - should see no records + mock_get_authenticated_user.return_value = None + result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}" finally: - # Clean up temporary SQLite database file if needed - if cleanup_path: - try: - os.unlink(cleanup_path) - except OSError: - pass + # Clean up records + await cleanup_records(authorized_store.sql_store, table_name, ["1", "2"]) diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index 1624c0ba7..61763719a 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -153,7 +153,9 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): policy_ids = set() for scenario in test_scenarios: sql_record = SqlRecord( - record_id=scenario["id"], table_name="resources", access_attributes=scenario["access_attributes"] + record_id=scenario["id"], + table_name="resources", + 
owner=User(principal="test-user", attributes=scenario["access_attributes"]), ) if is_action_allowed(policy, Action.READ, sql_record, user):