Merge branch 'main' into vectordb_name

Francisco Arceo 2025-07-11 13:46:16 -04:00 committed by GitHub
commit 67f1131040
36 changed files with 803 additions and 696 deletions


@@ -8,4 +8,6 @@ runs:
run: |
docker run -d --name ollama -p 11434:11434 docker.io/leseb/ollama-with-models
# TODO: rebuild an ollama image with llama-guard3:1b
echo "Verifying Ollama status..."
timeout 30 bash -c 'while ! curl -s -L http://127.0.0.1:11434; do sleep 1 && echo "."; done'
docker exec ollama ollama pull llama-guard3:1b


@@ -3,10 +3,10 @@ name: Installer CI
on:
pull_request:
paths:
- 'install.sh'
- 'scripts/install.sh'
push:
paths:
- 'install.sh'
- 'scripts/install.sh'
schedule:
- cron: '0 2 * * *' # every day at 02:00 UTC
@@ -16,11 +16,11 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
- name: Run ShellCheck on install.sh
run: shellcheck install.sh
run: shellcheck scripts/install.sh
smoke-test:
needs: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
- name: Run installer end-to-end
run: ./install.sh
run: ./scripts/install.sh


@@ -18,16 +18,33 @@ concurrency:
cancel-in-progress: true
jobs:
test-matrix:
discover-tests:
runs-on: ubuntu-latest
outputs:
test-type: ${{ steps.generate-matrix.outputs.test-type }}
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Generate test matrix
id: generate-matrix
run: |
# Get test directories dynamically, excluding non-test directories
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
grep -Ev "^(__pycache__|fixtures|test_cases)$" |
sort | jq -R -s -c 'split("\n")[:-1]')
echo "test-type=$TEST_TYPES" >> $GITHUB_OUTPUT
test-matrix:
needs: discover-tests
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
# Listing tests manually since some of them currently fail
# TODO: generate matrix list from tests/integration when fixed
test-type: [agents, inference, datasets, inspect, safety, scoring, post_training, providers, tool_runtime, vector_io]
test-type: ${{ fromJson(needs.discover-tests.outputs.test-type) }}
client-type: [library, server]
python-version: ["3.12", "3.13"]
fail-fast: false # we want to run all tests regardless of failure
steps:
- name: Checkout repository
@@ -51,23 +68,13 @@ jobs:
free -h
df -h
- name: Verify Ollama status is OK
if: matrix.client-type == 'http'
run: |
echo "Verifying Ollama status..."
ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status)
echo "Ollama status: $ollama_status"
if [ "$ollama_status" != "OK" ]; then
echo "Ollama health check failed"
exit 1
fi
- name: Run Integration Tests
env:
OLLAMA_INFERENCE_MODEL: "llama3.2:3b-instruct-fp16" # for server tests
ENABLE_OLLAMA: "ollama" # for server tests
OLLAMA_URL: "http://0.0.0.0:11434"
SAFETY_MODEL: "llama-guard3:1b"
LLAMA_STACK_CLIENT_TIMEOUT: "300" # Increased timeout for eval operations
# Use 'shell' to get pipefail behavior
# https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#exit-codes-and-error-action-preference
# TODO: write a precommit hook to detect if a test contains a pipe but does not use 'shell: bash'
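The `discover-tests` job above replaces the hand-maintained matrix list by scanning `tests/integration` at workflow time. The shell pipeline is equivalent to this small Python sketch (illustrative only, not part of the change):

```python
import json
from pathlib import Path

# Mirror of the generate-matrix step: list suite directories under
# tests/integration, drop non-test directories, and emit a JSON array.
EXCLUDED = {"__pycache__", "fixtures", "test_cases"}

def discover_test_types(root: str = "tests/integration") -> str:
    dirs = sorted(
        p.name for p in Path(root).iterdir()
        if p.is_dir() and p.name not in EXCLUDED
    )
    return json.dumps(dirs)

# In CI this value is written to $GITHUB_OUTPUT and consumed via fromJson().
print(f"test-type={discover_test_types()}")  # e.g. test-type=["agents", "inference", ...]
```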


@@ -29,7 +29,7 @@ repos:
- id: check-toml
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.4
rev: v1.5.5
hooks:
- id: insert-license
files: \.py$|\.sh$
@@ -38,7 +38,7 @@ repos:
- docs/license_header.txt
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.4
rev: v0.12.2
hooks:
- id: ruff
args: [ --fix ]
@@ -46,14 +46,14 @@ repos:
- id: ruff-format
- repo: https://github.com/adamchainz/blacken-docs
rev: 1.19.0
rev: 1.19.1
hooks:
- id: blacken-docs
additional_dependencies:
- black==24.3.0
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.7.8
rev: 0.7.20
hooks:
- id: uv-lock
- id: uv-export
@@ -66,7 +66,7 @@ repos:
]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
rev: v1.16.1
hooks:
- id: mypy
additional_dependencies:
@@ -133,3 +133,8 @@ repos:
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate
autofix_prs: true
autoupdate_branch: ''
autoupdate_schedule: weekly
skip: []
submodules: false


@@ -66,7 +66,7 @@ You can install the dependencies by running:
```bash
cd llama-stack
uv sync --extra dev
uv sync --group dev
uv pip install -e .
source .venv/bin/activate
```
@@ -168,7 +168,7 @@ manually as they are auto-generated.
### Updating the provider documentation
If you have made changes to a provider's configuration, you should run `./scripts/distro_codegen.py`
If you have made changes to a provider's configuration, you should run `./scripts/provider_codegen.py`
to re-generate the documentation. You should not change `docs/source/.../providers/` files manually
as they are auto-generated.
Note that the provider "description" field will be used to generate the provider documentation.


@@ -77,7 +77,7 @@ As more providers start supporting Llama 4, you can use them in Llama Stack as w
To try Llama Stack locally, run:
```bash
curl -LsSf https://github.com/meta-llama/llama-stack/raw/main/install.sh | bash
curl -LsSf https://github.com/meta-llama/llama-stack/raw/main/scripts/install.sh | bash
```
### Overview


@@ -1,5 +1,7 @@
# The Llama Stack API
*Originally authored Jul 23, 2024*
**Authors:**
* Meta: @raghotham, @ashwinb, @hjshah, @jspisak
@@ -24,7 +26,7 @@ Meta releases weights of both the pretrained and instruction fine-tuned Llama mo
### Model Lifecycle
![Figure 1: Model Life Cycle](../docs/resources/model-lifecycle.png)
![Figure 1: Model Life Cycle](resources/model-lifecycle.png)
For each of the operations that need to be performed during the model life cycle (e.g., fine-tuning, inference, evals), we identified the needed capabilities as toolchain APIs. Some of these capabilities are primitive operations like inference, while others, like synthetic data generation, are composed of other capabilities. The list of APIs we have identified to support the lifecycle of Llama models is below:
@@ -37,7 +39,7 @@ For each of the operations that need to be performed (e.g. fine tuning, inferenc
### Agentic System
![Figure 2: Agentic System](../docs/resources/agentic-system.png)
![Figure 2: Agentic System](resources/agentic-system.png)
In addition to the model lifecycle, we considered the different components involved in an agentic system, specifically around tool calling and shields. Since the model may decide to call tools, a single model inference call is not enough. What's needed is an agentic loop consisting of tool calls and inference. The model provides separate tokens representing end-of-message and end-of-turn. A message represents a possible stopping point for execution where the model can inform the execution environment that a tool call needs to be made. The execution environment, upon execution, adds the result back to the context window and makes another inference call. This process can be repeated until an end-of-turn token is generated.
Note that as of today, in the OSS world, such a “loop” is often coded explicitly via elaborate prompt engineering, typically using a ReAct pattern or a preconstructed execution graph. Llama 3.1 (and future Llamas) attempts to absorb this multi-step reasoning loop inside the main model itself.
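The loop described above is small enough to sketch directly. This is an illustration of the control flow only; the method and field names are hypothetical, not Llama Stack APIs:

```python
def agentic_loop(model, tools, messages):
    """Run inference until the model ends its turn instead of requesting a tool."""
    while True:
        output = model.infer(messages)           # one inference call
        if output.stop_reason == "end_of_turn":  # model is done with this turn
            return output
        # end-of-message: the model paused so a tool call can be made
        result = tools[output.tool_name](**output.tool_args)
        # add the tool result back to the context window and infer again
        messages.append({"role": "tool", "content": result})
```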
@@ -63,9 +65,9 @@ The sequence diagram that details the steps is [here](https://github.com/meta-ll
We define the Llama Stack as a layer cake shown below.
![Figure 3: Llama Stack](../docs/resources/llama-stack.png)
![Figure 3: Llama Stack](resources/llama-stack.png)
The API is defined in the [YAML](../docs/_static/llama-stack-spec.yaml) and [HTML](../docs/_static/llama-stack-spec.html) files.
The API is defined in the [YAML](_static/llama-stack-spec.yaml) and [HTML](_static/llama-stack-spec.html) files.
## Sample implementations


@@ -6,12 +6,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
export POSTGRES_USER=${POSTGRES_USER:-llamastack}
export POSTGRES_DB=${POSTGRES_DB:-llamastack}
export POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-llamastack}
export POSTGRES_USER=llamastack
export POSTGRES_DB=llamastack
export POSTGRES_PASSWORD=llamastack
export INFERENCE_MODEL=${INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct}
export SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
# HF_TOKEN should be set by the user; base64 encode it for the secret
if [ -n "${HF_TOKEN:-}" ]; then


@@ -32,7 +32,7 @@ spec:
image: vllm/vllm-openai:latest
command: ["/bin/sh", "-c"]
args:
- "vllm serve ${INFERENCE_MODEL} --dtype float16 --enforce-eager --max-model-len 4096 --gpu-memory-utilization 0.6"
- "vllm serve ${INFERENCE_MODEL} --dtype float16 --enforce-eager --max-model-len 4096 --gpu-memory-utilization 0.6 --enable-auto-tool-choice --tool-call-parser llama4_pythonic"
env:
- name: INFERENCE_MODEL
value: "${INFERENCE_MODEL}"
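With `--enable-auto-tool-choice` and a tool-call parser configured, the vLLM server accepts OpenAI-style tool definitions. A minimal client-side sketch (endpoint and model name are placeholders for the values injected via `${INFERENCE_MODEL}`):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")
response = client.chat.completions.create(
    model="placeholder-model",  # use the value of ${INFERENCE_MODEL}
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        },
    }],
)
# tool calls are extracted server-side by the configured parser
print(response.choices[0].message.tool_calls)
```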


@@ -11,7 +11,7 @@ Please refer to the remote provider documentation.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
## Sample Configuration


@@ -205,12 +205,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
| `db_path` | `<class 'str'>` | No | PydanticUndefined | Path to the SQLite database file |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |
## Sample Configuration
```yaml
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db
```


@@ -10,12 +10,16 @@ Please refer to the sqlite-vec provider documentation.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
| `db_path` | `<class 'str'>` | No | PydanticUndefined | Path to the SQLite database file |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |
## Sample Configuration
```yaml
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db
```


@@ -93,7 +93,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
elif args.providers:
providers = dict()
providers_list: dict[str, str | list[str]] = dict()
for api_provider in args.providers.split(","):
if "=" not in api_provider:
cprint(
@@ -112,7 +112,15 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
if provider in providers_for_api:
providers.setdefault(api, []).append(provider)
if api not in providers_list:
providers_list[api] = []
# Use type guarding to ensure we have a list
provider_value = providers_list[api]
if isinstance(provider_value, list):
provider_value.append(provider)
else:
# Convert string to list and append
providers_list[api] = [provider_value, provider]
else:
cprint(
f"{provider} is not a valid provider for the {api} API.",
@@ -121,7 +129,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
distribution_spec = DistributionSpec(
providers=providers,
providers=providers_list,
description=",".join(args.providers),
)
if not args.image_type:
@@ -182,7 +190,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
providers = dict()
providers: dict[str, str | list[str]] = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
if not available_providers:
@@ -371,10 +379,16 @@ def _run_stack_build_command_from_build_config(
if not image_name:
raise ValueError("Please specify an image name when building a venv image")
# At this point, image_name should be guaranteed to be a string
if image_name is None:
raise ValueError("image_name should not be None after validation")
if template_name:
build_dir = DISTRIBS_BASE_DIR / template_name
build_file_path = build_dir / f"{template_name}-build.yaml"
else:
if image_name is None:
raise ValueError("image_name cannot be None")
build_dir = DISTRIBS_BASE_DIR / image_name
build_file_path = build_dir / f"{image_name}-build.yaml"
@@ -395,7 +409,7 @@ def _run_stack_build_command_from_build_config(
build_file_path,
image_name,
template_or_config=template_name or config_path or str(build_file_path),
run_config=run_config_file,
run_config=run_config_file.as_posix() if run_config_file else None,
)
if return_code != 0:
raise RuntimeError(f"Failed to build image {image_name}")


@@ -83,46 +83,57 @@ class StackRun(Subcommand):
return ImageType.CONDA.value, args.image_name
return args.image_type, args.image_name
def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
"""Resolve config file path and template name from args.config"""
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
if not args.config:
return None, None
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
return config_file, template_name
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
if args.enable_ui:
self._start_ui_development_server(args.port)
image_type, image_name = self._get_image_type_and_name(args)
# Resolve config file and template name first
config_file, template_name = self._resolve_config_and_template(args)
# Check if config is required based on image type
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not args.config:
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:
self.parser.error("Config file is required for venv and conda environments")
if args.config:
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
if config_file:
logger.info(f"Using run configuration: {config_file}")
try:
@@ -138,8 +149,6 @@ class StackRun(Subcommand):
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
else:
config = None
config_file = None
template_name = None
# If neither image type nor image name is provided, assume the server should be run directly
# using the current environment packages.
@@ -172,10 +181,7 @@ class StackRun(Subcommand):
run_args.extend([str(args.port)])
if config_file:
if template_name:
run_args.extend(["--template", str(template_name)])
else:
run_args.extend(["--config", str(config_file)])
run_args.extend(["--config", str(config_file)])
if args.env:
for env_var in args.env:


@@ -81,7 +81,7 @@ def is_action_allowed(
if not len(policy):
policy = default_policy()
qualified_resource_id = resource.type + "::" + resource.identifier
qualified_resource_id = f"{resource.type}::{resource.identifier}"
for rule in policy:
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
if rule.when:


@@ -96,7 +96,7 @@ FROM $container_base
WORKDIR /app
# We install the Python 3.12 dev headers and build tools so that any
# Cextension wheels (e.g. polyleven, faisscpu) can compile successfully.
# C-extension wheels (e.g. polyleven, faiss-cpu) can compile successfully.
RUN dnf -y update && dnf install -y iputils git net-tools wget \
vim-minimal python3.12 python3.12-pip python3.12-wheel \
@@ -169,7 +169,7 @@ if [ -n "$run_config" ]; then
echo "Copying external providers directory: $external_providers_dir"
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
add_to_container << EOF
COPY --chmod=g+w providers.d /.llama/providers.d
COPY providers.d /.llama/providers.d
EOF
fi


@@ -445,7 +445,7 @@ def main(args: argparse.Namespace | None = None):
logger.info(log_line)
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump())
safe_config = redact_sensitive_fields(config.model_dump(mode="json"))
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(
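Passing `mode="json"` makes Pydantic coerce non-JSON-native values (paths, enums, datetimes) into plain strings and numbers before the config is handed to `yaml.dump`, which cannot represent arbitrary Python objects. A standalone illustration:

```python
from pathlib import Path

import yaml
from pydantic import BaseModel

class Config(BaseModel):
    db_path: Path

cfg = Config(db_path="~/.llama/store.db")
print(cfg.model_dump())             # {'db_path': PosixPath('~/.llama/store.db')}
print(cfg.model_dump(mode="json"))  # {'db_path': '~/.llama/store.db'}

yaml.dump(cfg.model_dump(mode="json"))  # fine
yaml.dump(cfg.model_dump())             # raises RepresenterError on the Path object
```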


@@ -6,12 +6,9 @@
from collections.abc import AsyncGenerator
from contextvars import ContextVar
from typing import TypeVar
T = TypeVar("T")
def preserve_contexts_async_generator(
def preserve_contexts_async_generator[T](
gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
) -> AsyncGenerator[T, None]:
"""


@@ -18,7 +18,7 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class MilvusVectorIOConfig(BaseModel):
db_path: str
kvstore: KVStoreConfig
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
@classmethod


@@ -6,14 +6,24 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
class SQLiteVectorIOConfig(BaseModel):
db_path: str
db_path: str = Field(description="Path to the SQLite database file")
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="sqlite_vec_registry.db",
),
}
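The `${env.SQLITE_STORE_DIR:=...}` placeholders emitted by `sample_run_config` are resolved when the stack loads its run config, with `:=` supplying a default for unset variables. A rough sketch of that expansion rule (a hypothetical mini-resolver, not the actual implementation):

```python
import os
import re

_PLACEHOLDER = re.compile(r"\$\{env\.(\w+):=([^}]*)\}")

def expand_env_defaults(value: str) -> str:
    """Replace ${env.NAME:=default} with $NAME, falling back to the default."""
    return _PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

print(expand_env_defaults("${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db"))
# -> ~/.llama/dummy/sqlite_vec.db when SQLITE_STORE_DIR is unset
```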


@@ -24,6 +24,8 @@ from llama_stack.apis.vector_io import (
VectorIO,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
@@ -40,6 +42,13 @@ KEYWORD_SEARCH = "keyword"
HYBRID_SEARCH = "hybrid"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:sqlite_vec:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:sqlite_vec:{VERSION}::"
def serialize_vector(vector: list[float]) -> bytes:
"""Serialize a list of floats into a compact binary representation."""
@@ -117,13 +126,14 @@ class SQLiteVecIndex(EmbeddingIndex):
- An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
"""
def __init__(self, dimension: int, db_path: str, bank_id: str):
def __init__(self, dimension: int, db_path: str, bank_id: str, kvstore: KVStore | None = None):
self.dimension = dimension
self.db_path = db_path
self.bank_id = bank_id
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
self.kvstore = kvstore
@classmethod
async def create(cls, dimension: int, db_path: str, bank_id: str):
@@ -425,27 +435,116 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
self.files_api = files_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.kvstore: KVStore | None = None
async def initialize(self) -> None:
def _setup_connection():
# Open a connection to the SQLite database (the file is specified in the config).
self.kvstore = await kvstore_impl(self.config.kvstore)
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for db_json in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(db_json)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
# load any existing OpenAI vector stores
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
pass
async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()]
async def register_vector_db(self, vector_db: VectorDB) -> None:
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
if self.vector_db_store is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
vector_db = self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found")
index = VectorDBWithIndex(
vector_db=vector_db,
index=SQLiteVecIndex(
dimension=vector_db.embedding_dimension,
db_path=self.config.db_path,
bank_id=vector_db.identifier,
kvstore=self.kvstore,
),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
return index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to SQLite database."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
self.openai_vector_stores[store_id] = store_info
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from SQLite database."""
assert self.kvstore is not None
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
stores = {}
for store_data in stored_openai_stores:
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in SQLite database."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
self.openai_vector_stores[store_id] = store_info
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from SQLite database."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
if store_id in self.openai_vector_stores:
del self.openai_vector_stores[store_id]
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to SQLite database."""
def _create_or_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
# Create a table to persist vector DB registrations.
cur.execute("""
CREATE TABLE IF NOT EXISTS vector_dbs (
id TEXT PRIMARY KEY,
metadata TEXT
);
""")
# Create a table to persist OpenAI vector stores.
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_stores (
id TEXT PRIMARY KEY,
metadata TEXT
);
""")
# Create a table to persist OpenAI vector store files.
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
@@ -464,168 +563,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
);
""")
connection.commit()
# Load any existing vector DB registrations.
cur.execute("SELECT metadata FROM vector_dbs")
vector_db_rows = cur.fetchall()
return vector_db_rows
finally:
cur.close()
connection.close()
vector_db_rows = await asyncio.to_thread(_setup_connection)
# Load existing vector DBs
for row in vector_db_rows:
vector_db_data = row[0]
vector_db = VectorDB.model_validate_json(vector_db_data)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
# Load existing OpenAI vector stores using the mixin method
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
pass
async def register_vector_db(self, vector_db: VectorDB) -> None:
def _register_db():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
(vector_db.identifier, vector_db.model_dump_json()),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_register_db)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()]
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
def _delete_vector_db_from_registry():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_vector_db_from_registry)
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to SQLite database."""
def _store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
(store_id, json.dumps(store_info)),
)
connection.commit()
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
finally:
cur.close()
connection.close()
try:
await asyncio.to_thread(_store)
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("SELECT metadata FROM openai_vector_stores")
rows = cur.fetchall()
return rows
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_load)
stores = {}
for row in rows:
store_data = row[0]
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in SQLite database."""
def _update():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
(json.dumps(store_info), store_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update)
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from SQLite database."""
def _delete():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete)
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to SQLite database."""
def _store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)",
(store_id, file_id, json.dumps(file_info)),
@@ -643,7 +580,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
connection.close()
try:
await asyncio.to_thread(_store)
await asyncio.to_thread(_create_or_store)
except Exception as e:
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
raise
@@ -722,6 +659,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
cur.execute(
"DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id)
)
cur.execute(
"DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
(store_id, file_id),
)
connection.commit()
finally:
cur.close()
@@ -730,15 +671,17 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await asyncio.to_thread(_delete)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
# and then call our index's add_chunks.
await self.cache[vector_db_id].insert_chunks(chunks)
await index.insert_chunks(chunks)
async def query_chunks(
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
if vector_db_id not in self.cache:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params)
return await index.query_chunks(query, params)
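A recurring pattern in the new kvstore-backed code is the prefix range scan: keys share a prefix such as `openai_vector_stores:sqlite_vec:v3::`, and `values_in_range(prefix, prefix + "\xff")` returns every value whose key starts with that prefix, because `"\xff"` sorts after any ASCII character that can appear in an identifier. A self-contained sketch of the idea over an ordered toy store:

```python
import bisect

# Toy ordered KV store illustrating why prefix + "\xff" bounds a prefix scan.
store = {
    "openai_vector_stores:sqlite_vec:v3::store-a": "store a metadata",
    "openai_vector_stores:sqlite_vec:v3::store-b": "store b metadata",
    "vector_dbs:sqlite_vec:v3::db-1": "vector db metadata",
}
keys = sorted(store)

def values_in_range(start_key: str, end_key: str) -> list[str]:
    lo = bisect.bisect_left(keys, start_key)
    hi = bisect.bisect_right(keys, end_key)
    return [store[k] for k in keys[lo:hi]]

prefix = "openai_vector_stores:sqlite_vec:v3::"
print(values_in_range(prefix, prefix + "\xff"))  # only the two vector-store entries
```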


@@ -61,6 +61,11 @@ class MilvusIndex(EmbeddingIndex):
self.consistency_level = consistency_level
self.kvstore = kvstore
async def initialize(self):
# MilvusIndex does not require explicit initialization
# TODO: could move collection creation into initialization but it is not really necessary
pass
async def delete(self):
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
@@ -199,6 +204,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if vector_db_id in self.cache:
return self.cache[vector_db_id]
if self.vector_db_store is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found")


@@ -39,22 +39,10 @@ SQL_OPTIMIZED_POLICY = [
class SqlRecord(ProtectedResource):
"""Simple ProtectedResource implementation for SQL records."""
def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
def __init__(self, record_id: str, table_name: str, owner: User):
self.type = f"sql_record::{table_name}"
self.identifier = record_id
if access_attributes:
self.owner = User(
principal="system",
attributes=access_attributes,
)
else:
self.owner = User(
principal="system_public",
attributes=None,
)
self.owner = owner
class AuthorizedSqlStore:
@@ -101,22 +89,27 @@ class AuthorizedSqlStore:
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""Create a table with built-in access control support."""
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
enhanced_schema = dict(schema)
if "access_attributes" not in enhanced_schema:
enhanced_schema["access_attributes"] = ColumnType.JSON
if "owner_principal" not in enhanced_schema:
enhanced_schema["owner_principal"] = ColumnType.STRING
await self.sql_store.create_table(table, enhanced_schema)
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
"""Insert a row with automatic access control attribute capture."""
enhanced_data = dict(data)
current_user = get_authenticated_user()
if current_user and current_user.attributes:
if current_user:
enhanced_data["owner_principal"] = current_user.principal
enhanced_data["access_attributes"] = current_user.attributes
else:
enhanced_data["owner_principal"] = None
enhanced_data["access_attributes"] = None
await self.sql_store.insert(table, enhanced_data)
@@ -146,9 +139,12 @@ class AuthorizedSqlStore:
for row in rows.data:
stored_access_attrs = row.get("access_attributes")
stored_owner_principal = row.get("owner_principal") or ""
record_id = row.get("id", "unknown")
sql_record = SqlRecord(str(record_id), table, stored_access_attrs)
sql_record = SqlRecord(
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
)
if is_action_allowed(policy, Action.READ, sql_record, current_user):
filtered_rows.append(row)
@@ -186,8 +182,10 @@ class AuthorizedSqlStore:
Only applies SQL filtering for the default policy to ensure correctness.
For custom policies, uses conservative filtering to avoid blocking legitimate access.
"""
current_user = get_authenticated_user()
if not policy or policy == SQL_OPTIMIZED_POLICY:
return self._build_default_policy_where_clause()
return self._build_default_policy_where_clause(current_user)
else:
return self._build_conservative_where_clause()
@@ -227,29 +225,27 @@ class AuthorizedSqlStore:
def _get_public_access_conditions(self) -> list[str]:
"""Get the SQL conditions for public access."""
# Public records are records that have no owner_principal or access_attributes
conditions = ["owner_principal = ''"]
if self.database_type == SqlStoreType.postgres:
# Postgres stores JSON null as 'null'
return ["access_attributes::text = 'null'"]
conditions.append("access_attributes::text = 'null'")
elif self.database_type == SqlStoreType.sqlite:
return ["access_attributes = 'null'"]
conditions.append("access_attributes = 'null'")
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
return conditions
def _build_default_policy_where_clause(self) -> str:
def _build_default_policy_where_clause(self, current_user: User | None) -> str:
"""Build SQL WHERE clause for the default policy.
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
This means user must match ALL attribute categories that exist in the resource.
"""
current_user = get_authenticated_user()
base_conditions = self._get_public_access_conditions()
if not current_user or not current_user.attributes:
# Only allow public records
return f"({' OR '.join(base_conditions)})"
else:
user_attr_conditions = []
user_attr_conditions = []
if current_user and current_user.attributes:
for attr_key, user_values in current_user.attributes.items():
if user_values:
value_conditions = []
@@ -269,7 +265,7 @@ class AuthorizedSqlStore:
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
base_conditions.append(all_requirements_met)
return f"({' OR '.join(base_conditions)})"
return f"({' OR '.join(base_conditions)})"
def _build_conservative_where_clause(self) -> str:
"""Conservative SQL filtering for custom policies.


@@ -244,35 +244,41 @@ class SqlAlchemySqlStoreImpl(SqlStore):
engine = create_async_engine(self.config.engine_str)
try:
inspector = inspect(engine)
table_names = inspector.get_table_names()
if table not in table_names:
return
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
if column_name in column_names:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
async with engine.begin() as conn:
def check_column_exists(sync_conn):
inspector = inspect(sync_conn)
table_names = inspector.get_table_names()
if table not in table_names:
return False, False # table doesn't exist, column doesn't exist
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
return True, column_name in column_names # table exists, column exists or not
table_exists, column_exists = await conn.run_sync(check_column_exists)
if not table_exists or column_exists:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
await conn.execute(add_column_sql)
except Exception:
except Exception as e:
# If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
pass


@@ -9,14 +9,12 @@ import inspect
import json
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any, TypeVar
from typing import Any
from pydantic import BaseModel
from llama_stack.models.llama.datatypes import Primitive
T = TypeVar("T")
def serialize_value(value: Any) -> Primitive:
return str(_prepare_for_json(value))
@@ -44,7 +42,7 @@ def _prepare_for_json(value: Any) -> str:
return str(value)
def trace_protocol(cls: type[T]) -> type[T]:
def trace_protocol[T](cls: type[T]) -> type[T]:
"""
A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes.


@@ -39,6 +39,9 @@ providers:
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec_registry.db
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb
config:


@@ -144,6 +144,9 @@ providers:
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
- provider_id: ${env.ENABLE_MILVUS:=__disabled__}
provider_type: inline::milvus
config:


@@ -42,8 +42,8 @@ dependencies = [
"h11>=0.16.0",
"python-multipart>=0.0.20", # For fastapi Form
"uvicorn>=0.34.0", # server
"opentelemetry-sdk", # server
"opentelemetry-exporter-otlp-proto-http", # server
"opentelemetry-sdk>=1.30.0", # server
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
"aiosqlite>=0.21.0", # server - for metadata store
"asyncpg", # for metadata store
]
@@ -226,7 +226,6 @@ follow_imports = "silent"
exclude = [
# As we fix more and more of these, we should remove them from the list
"^llama_stack/cli/download\\.py$",
"^llama_stack/cli/stack/_build\\.py$",
"^llama_stack/distribution/build\\.py$",
"^llama_stack/distribution/client\\.py$",
"^llama_stack/distribution/request_headers\\.py$",


@@ -6,6 +6,7 @@
import inspect
import os
import signal
import socket
import subprocess
import tempfile
@@ -45,6 +46,8 @@ def start_llama_stack_server(config_name: str) -> subprocess.Popen:
stderr=subprocess.PIPE, # keep stderr to see errors
text=True,
env={**os.environ, "LLAMA_STACK_LOG_FILE": "server.log"},
# Create new process group so we can kill all child processes
preexec_fn=os.setsid,
)
return process
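The new `preexec_fn=os.setsid` call puts the server in its own process group so that the cleanup fixture further down can signal the whole tree, including any children the server spawns. A self-contained sketch of the pattern (POSIX-only):

```python
import os
import signal
import subprocess

# Start the child in its own session/process group.
proc = subprocess.Popen(["sleep", "300"], preexec_fn=os.setsid)

# Later: signal the entire group so grandchildren terminate too.
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
try:
    proc.wait(timeout=10)
except subprocess.TimeoutExpired:
    os.killpg(os.getpgid(proc.pid), signal.SIGKILL)  # escalate if needed
    proc.wait()
```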
@@ -197,7 +200,7 @@ def llama_stack_client(request, provider_data):
server_process = start_llama_stack_server(config_name)
# Wait for server to be ready
if not wait_for_server_ready(base_url, timeout=30, process=server_process):
if not wait_for_server_ready(base_url, timeout=120, process=server_process):
print("Server failed to start within timeout")
server_process.terminate()
raise RuntimeError(
@@ -215,6 +218,7 @@ def llama_stack_client(request, provider_data):
return LlamaStackClient(
base_url=base_url,
provider_data=provider_data,
timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")),
)
# check if this looks like a URL using proper URL parsing
@@ -267,14 +271,17 @@ def cleanup_server_process(request):
print(f"Server process already terminated with return code: {server_process.returncode}")
return
try:
server_process.terminate()
print(f"Terminating process {server_process.pid} and its group...")
# Kill the entire process group
os.killpg(os.getpgid(server_process.pid), signal.SIGTERM)
server_process.wait(timeout=10)
print("Server process terminated gracefully")
print("Server process and children terminated gracefully")
except subprocess.TimeoutExpired:
print("Server process did not terminate gracefully, killing it")
server_process.kill()
# Force kill the entire process group
os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)
server_process.wait()
print("Server process killed")
print("Server process and children killed")
except Exception as e:
print(f"Error during server cleanup: {e}")
else:


@@ -14,8 +14,7 @@ from llama_stack.distribution.access_control.access_control import default_polic
from llama_stack.distribution.datatypes import User
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
def get_postgres_config():
@@ -30,144 +29,213 @@ def get_postgres_config():
def get_sqlite_config():
"""Get SQLite configuration with temporary database."""
tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
tmp_file.close()
return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name
"""Get SQLite configuration with temporary file database."""
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
temp_file.close()
return SqliteSqlStoreConfig(db_path=temp_file.name)
# Backend configurations for parametrized tests
BACKEND_CONFIGS = [
pytest.param(
get_postgres_config,
marks=pytest.mark.skipif(
not os.environ.get("ENABLE_POSTGRES_TESTS"),
reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
),
id="postgres",
),
pytest.param(get_sqlite_config, id="sqlite"),
]
@pytest.fixture
def authorized_store(backend_config):
"""Set up authorized store with proper cleanup."""
config_func = backend_config
config = config_func()
base_sqlstore = sqlstore_impl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore)
yield authorized_store
if hasattr(config, "db_path"):
try:
os.unlink(config.db_path)
except (OSError, FileNotFoundError):
pass
async def create_test_table(authorized_store, table_name):
"""Create a test table with standard schema."""
await authorized_store.create_table(
table=table_name,
schema={
"id": ColumnType.STRING,
"data": ColumnType.STRING,
},
)
async def cleanup_records(sql_store, table_name, record_ids):
"""Clean up test records."""
for record_id in record_ids:
try:
await sql_store.delete(table_name, {"id": record_id})
except Exception:
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"backend_config",
[
pytest.param(
("postgres", get_postgres_config),
marks=pytest.mark.skipif(
not os.environ.get("ENABLE_POSTGRES_TESTS"),
reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
),
id="postgres",
),
pytest.param(("sqlite", get_sqlite_config), id="sqlite"),
],
)
@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_json_comparison(mock_get_authenticated_user, backend_config):
async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request):
"""Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
backend_name, config_func = backend_config
backend_name = request.node.callspec.id
# Handle different config types
if backend_name == "postgres":
config = config_func()
cleanup_path = None
else: # sqlite
config, cleanup_path = config_func()
# Create test table
table_name = f"test_json_comparison_{backend_name}"
await create_test_table(authorized_store, table_name)
try:
base_sqlstore = SqlAlchemySqlStoreImpl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore)
# Test with no authenticated user (should handle JSON null comparison)
mock_get_authenticated_user.return_value = None
# Create test table
table_name = f"test_json_comparison_{backend_name}"
await authorized_store.create_table(
table=table_name,
schema={
"id": ColumnType.STRING,
"data": ColumnType.STRING,
},
# Insert some test data
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
# Test fetching with no user - should not error on JSON comparison
result = await authorized_store.fetch_all(table_name, policy=default_policy())
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
assert result.data[0]["access_attributes"] is None
# Test with authenticated user
test_user = User("test-user", {"roles": ["admin"]})
mock_get_authenticated_user.return_value = test_user
# Insert data with user attributes
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
# Fetch all - admin should see both
result = await authorized_store.fetch_all(table_name, policy=default_policy())
assert len(result.data) == 2
# Test with non-admin user
regular_user = User("regular-user", {"roles": ["user"]})
mock_get_authenticated_user.return_value = regular_user
# Should only see public record
result = await authorized_store.fetch_all(table_name, policy=default_policy())
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
# Test the category missing branch: user with multiple attributes
multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]})
mock_get_authenticated_user.return_value = multi_user
# Insert record with multi-user (has both roles and teams)
await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"})
# Test different user types to create records with different attribute patterns
# Record with only roles (teams category will be missing)
roles_only_user = User("roles-user", {"roles": ["admin"]})
mock_get_authenticated_user.return_value = roles_only_user
await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"})
# Record with only teams (roles category will be missing)
teams_only_user = User("teams-user", {"teams": ["dev"]})
mock_get_authenticated_user.return_value = teams_only_user
await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"})
# Record with different roles/teams (shouldn't match our test user)
different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]})
mock_get_authenticated_user.return_value = different_user
await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"})
# Now test with the multi-user who has both roles=admin and teams=dev
mock_get_authenticated_user.return_value = multi_user
result = await authorized_store.fetch_all(table_name, policy=default_policy())
# Should see:
# - public record (1) - no access_attributes
# - admin record (2) - user matches roles=admin, teams missing (allowed)
# - multi_user record (3) - user matches both roles=admin and teams=dev
# - roles_only record (4) - user matches roles=admin, teams missing (allowed)
# - teams_only record (5) - user matches teams=dev, roles missing (allowed)
# Should NOT see:
# - different_user record (6) - user doesn't match roles=user or teams=qa
expected_ids = {"1", "2", "3", "4", "5"}
actual_ids = {record["id"] for record in result.data}
assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}"
# Verify the category missing logic specifically
# Records 4 and 5 test the "category missing" branch where one attribute category is missing
category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]}
assert category_test_ids == {"4", "5"}, (
f"Category missing logic failed: expected 4,5 but got {category_test_ids}"
)
try:
# Test with no authenticated user (should handle JSON null comparison)
mock_get_authenticated_user.return_value = None
finally:
# Clean up records
await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"])
# Insert some test data
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
# Test fetching with no user - should not error on JSON comparison
result = await authorized_store.fetch_all(table_name, policy=default_policy())
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
assert result.data[0]["access_attributes"] is None
@pytest.mark.asyncio
@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request):
    """Test that 'user is owner' policies work correctly with record ownership"""
    from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope

    backend_name = request.node.callspec.id

    # Create test table
    table_name = f"test_ownership_{backend_name}"
    await create_test_table(authorized_store, table_name)

    try:
        # Test with first user who creates records
        user1 = User("user1", {"roles": ["admin"]})
        mock_get_authenticated_user.return_value = user1

        # Insert a record owned by user1
        await authorized_store.insert(table_name, {"id": "1", "data": "user1_data"})

        # Test with second user
        user2 = User("user2", {"roles": ["user"]})
        mock_get_authenticated_user.return_value = user2

        # Insert a record owned by user2
        await authorized_store.insert(table_name, {"id": "2", "data": "user2_data"})

        # Create a policy that only allows access when user is the owner
        owner_only_policy = [
            AccessRule(
                permit=Scope(actions=[Action.READ]),
                when=["user is owner"],
            ),
        ]

        # Test user1 access - should only see their own record
        mock_get_authenticated_user.return_value = user1
        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
        assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
        assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"

        # Test user2 access - should only see their own record
        mock_get_authenticated_user.return_value = user2
        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
        assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
        assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"

        # Test with anonymous user - should see no records
        mock_get_authenticated_user.return_value = None
        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
        assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
    finally:
        # Clean up records
        await cleanup_records(authorized_store.sql_store, table_name, ["1", "2"])
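# A minimal sketch of how an owner-only rule should evaluate for a single record,
# assuming is_action_allowed(policy, action, record, user) behaves as in
# test_sql_policy_consistency further below (the import path for is_action_allowed
# is an assumption, not taken from this diff):
#
# owner_only = [AccessRule(permit=Scope(actions=[Action.READ]), when=["user is owner"])]
# record = SqlRecord(record_id="1", table_name="resources", owner=User(principal="user1", attributes={}))
# assert is_action_allowed(owner_only, Action.READ, record, User(principal="user1", attributes={}))
# assert not is_action_allowed(owner_only, Action.READ, record, User(principal="user2", attributes={}))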

View file

@@ -8,10 +8,18 @@ import random
import numpy as np
import pytest
from pymilvus import MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
EMBEDDING_DIMENSION = 384
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture
@@ -50,7 +58,156 @@ def sample_chunks():
return sample
@pytest.fixture(scope="session")
def sample_chunks_with_metadata():
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
n, k = 10, 3
sample = [
Chunk(
content=f"Sentence {i} from document {j}",
metadata={"document_id": f"document-{j}"},
chunk_metadata=ChunkMetadata(
document_id=f"document-{j}",
chunk_id=f"document-{j}-chunk-{i}",
source=f"example source-{j}-{i}",
),
)
for j in range(k)
for i in range(n)
]
return sample
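# With n=10 sentences per document and k=3 documents, the fixture yields 30 chunks
# whose ids follow "document-{j}-chunk-{i}". Inserting them with a small batch_size
# splits a single document across batches, which is exactly the id-conflict
# scenario this fixture is designed to expose, e.g.:
#
# chunks = sample_chunks_with_metadata  # 30 chunks
# assert chunks[0].chunk_metadata.chunk_id == "document-0-chunk-0"
# assert chunks[-1].chunk_metadata.chunk_id == "document-2-chunk-9"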
@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
@pytest.fixture(scope="session")
def sample_embeddings_with_metadata(sample_chunks_with_metadata):
np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
@pytest.fixture(params=["milvus", "sqlite_vec"])
def vector_provider(request):
return request.param
@pytest.fixture(scope="session")
def mock_inference_api(embedding_dimension):
class MockInferenceAPI:
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]
return MockInferenceAPI()
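# Note: this mock returns random vectors of the requested dimension, so tests built
# on it exercise shape- and plumbing-level behavior (dimensions, batching, storage
# round-trips) rather than semantic similarity.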
@pytest.fixture
async def unique_kvstore_config(tmp_path_factory):
# Generate a unique filename for this test
unique_id = f"test_kv_{np.random.randint(1e6)}"
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"{unique_id}.db")
return SqliteKVStoreConfig(db_path=db_path)
@pytest.fixture(scope="session")
def sqlite_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test.db")
return db_path
@pytest.fixture
async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"test_sqlite_vec_{np.random.randint(1e6)}.db")
bank_id = f"sqlite_vec_bank_{np.random.randint(1e6)}"
index = SQLiteVecIndex(embedding_dimension, db_path, bank_id)
await index.initialize()
index.db_path = db_path
yield index
index.delete()
@pytest.fixture
async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
config = SQLiteVectorIOConfig(
db_path=sqlite_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = SQLiteVecVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
@pytest.fixture(scope="session")
def milvus_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
return db_path
@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
client = MilvusClient(milvus_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = milvus_vec_db_path
yield index
@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
config = MilvusVectorIOConfig(
db_path=milvus_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = MilvusVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=adapter.metadata_collection_name,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def vector_io_adapter(vector_provider, request):
"""Returns the appropriate vector IO adapter based on the provider parameter."""
if vector_provider == "milvus":
return request.getfixturevalue("milvus_vec_adapter")
else:
return request.getfixturevalue("sqlite_vec_adapter")
@pytest.fixture
def vector_index(vector_provider, request):
"""Returns appropriate vector index based on provider parameter"""
return request.getfixturevalue(f"{vector_provider}_vec_index")

View file

@@ -34,7 +34,7 @@ def loop():
return asyncio.new_event_loop()
@pytest_asyncio.fixture(scope="session", autouse=True)
@pytest_asyncio.fixture
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db")
@@ -44,38 +44,15 @@ async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
@pytest.mark.asyncio
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
connection = _create_sqlite_connection(sqlite_vec_index.db_path)
cur = connection.cursor()
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
count = cur.fetchone()[0]
assert count == len(sample_chunks)
cur.close()
connection.close()
@pytest.mark.asyncio
async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata):
    await sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)
    response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0)
    assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata
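# Note: the query embedding is the stored embedding of the last chunk, so the
# top-1 hit (response.chunks[0]) should be that same chunk; the assertion checks
# that chunk_metadata round-trips through the index rather than ranking behavior.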
@pytest.mark.asyncio
async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_string = "Sentence 5"
response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string)
@@ -148,7 +125,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
@pytest_asyncio.fixture(scope="session")
@pytest.fixture(scope="session")
async def sqlite_vec_adapter(sqlite_connection):
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)

View file

@@ -4,253 +4,142 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import time
from unittest.mock import AsyncMock

import numpy as np
import pytest

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX

# This file contains unit tests for the inline VectorIO providers. It should only
# contain tests which are specific to these classes. More general (API-level) tests
# should be placed in tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py \
#   -v -s --tb=short --disable-warnings --asyncio-mode=auto
@pytest.mark.asyncio
async def test_initialize_index(vector_index):
    await vector_index.initialize()
@pytest.mark.asyncio
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
    vector_index.delete()
    vector_index.initialize()
    await vector_index.add_chunks(sample_chunks, sample_embeddings)
    resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
    assert resp.chunks[0].content == sample_chunks[0].content
    vector_index.delete()
@pytest.mark.asyncio
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
    embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
    await vector_index.add_chunks(sample_chunks, embeddings)
    resp = await vector_index.query_vector(
        np.random.rand(embedding_dimension).astype(np.float32),
        k=len(sample_chunks),
        score_threshold=-1,
    )
    contents = [chunk.content for chunk in resp.chunks]
    assert len(contents) == len(set(contents))
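# Why contents instead of raw ids: if two batches reused a chunk id, one chunk
# would overwrite the other, so a query with k=len(sample_chunks) would return
# fewer distinct contents than chunks inserted. This keeps the check
# provider-agnostic (no pymilvus-specific id queries).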
@pytest.mark.asyncio
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
    key = f"{VECTOR_DBS_PREFIX}db1"
    dummy = VectorDB(
        identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
    )
    await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))

    await vector_io_adapter.initialize()
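# Seeding the kvstore under VECTOR_DBS_PREFIX before calling initialize() checks
# that the adapter rehydrates previously registered vector DBs from persistent
# storage on startup, without going through register_vector_db in-process.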
@pytest.mark.asyncio
async def test_persistence_across_adapter_restarts(vector_io_adapter):
    await vector_io_adapter.initialize()
    dummy = VectorDB(
        identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
    )
    await vector_io_adapter.register_vector_db(dummy)
    await vector_io_adapter.shutdown()

    await vector_io_adapter.initialize()
    assert "foo_db" in vector_io_adapter.cache
    await vector_io_adapter.shutdown()
@pytest.mark.asyncio
async def test_register_and_unregister_vector_db(vector_io_adapter):
    unique_id = f"foo_db_{np.random.randint(1e6)}"
    dummy = VectorDB(
        identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
    )
    await vector_io_adapter.register_vector_db(dummy)
    assert dummy.identifier in vector_io_adapter.cache
    await vector_io_adapter.unregister_vector_db(dummy.identifier)
    assert dummy.identifier not in vector_io_adapter.cache
@pytest.mark.asyncio
async def test_query_unregistered_raises(vector_io_adapter):
    fake_emb = np.zeros(8, dtype=np.float32)
    with pytest.raises(ValueError):
        await vector_io_adapter.query_chunks("no_such_db", fake_emb)
@pytest.mark.asyncio
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
    fake_index = AsyncMock()
    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)

    chunks = ["chunk1", "chunk2"]
    await vector_io_adapter.insert_chunks("db1", chunks)

    fake_index.insert_chunks.assert_awaited_once_with(chunks)
@pytest.mark.asyncio
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)

    with pytest.raises(ValueError):
        await vector_io_adapter.insert_chunks("db_not_exist", [])
@pytest.mark.asyncio
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
    expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
    fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)

    response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1})
fake_index.query_chunks.assert_awaited_once_with("my_query", {"param": 1})
assert response is expected
@pytest.mark.asyncio
async def test_query_chunks_missing_db_raises(vector_io_adapter):
    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)

    with pytest.raises(ValueError):
        await vector_io_adapter.query_chunks("db_missing", "q", None)
@pytest.mark.asyncio
async def test_save_openai_vector_store(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
"id": store_id,
@@ -260,14 +149,14 @@ async def test_save_openai_vector_store(milvus_vec_adapter):
"embedding_model": "test_model",
}
    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)

    assert openai_vector_store["id"] in vector_io_adapter.openai_vector_stores
    assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
@pytest.mark.asyncio
async def test_update_openai_vector_store(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
"id": store_id,
@@ -277,14 +166,14 @@ async def test_update_openai_vector_store(milvus_vec_adapter):
"embedding_model": "test_model",
}
    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)

    openai_vector_store["description"] = "Updated description"
    await vector_io_adapter._update_openai_vector_store(store_id, openai_vector_store)

    assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
@pytest.mark.asyncio
async def test_delete_openai_vector_store(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
"id": store_id,
@@ -294,13 +183,13 @@ async def test_delete_openai_vector_store(milvus_vec_adapter):
"embedding_model": "test_model",
}
    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
    await vector_io_adapter._delete_openai_vector_store_from_storage(store_id)

    assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores
@pytest.mark.asyncio
async def test_load_openai_vector_stores(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
"id": store_id,
@@ -310,13 +199,13 @@ async def test_load_openai_vector_stores(milvus_vec_adapter):
"embedding_model": "test_model",
}
    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
    loaded_stores = await vector_io_adapter._load_openai_vector_stores()
assert loaded_stores[store_id] == openai_vector_store
@pytest.mark.asyncio
async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@@ -334,11 +223,11 @@ async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
]
# validating we don't raise an exception
    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
@pytest.mark.asyncio
async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@@ -355,24 +244,24 @@ async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
]
    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
updated_file_info = file_info.copy()
updated_file_info["filename"] = "updated_test_file.txt"
    await vector_io_adapter._update_openai_vector_store_file(
store_id,
file_id,
updated_file_info,
)
    loaded_contents = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id)
assert loaded_contents == updated_file_info
assert loaded_contents != file_info
@pytest.mark.asyncio
async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@@ -389,14 +278,14 @@ async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_path_factory):
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
]
    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)

    loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
assert loaded_contents == file_contents
@pytest.mark.asyncio
async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@@ -413,8 +302,8 @@ async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, tmp_path_factory):
{"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
]
    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
    await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)

    loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
assert loaded_contents == []

View file

@@ -153,7 +153,9 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
policy_ids = set()
for scenario in test_scenarios:
        sql_record = SqlRecord(
            record_id=scenario["id"],
            table_name="resources",
            owner=User(principal="test-user", attributes=scenario["access_attributes"]),
        )
if is_action_allowed(policy, Action.READ, sql_record, user):

uv.lock generated
View file

@@ -1365,8 +1365,8 @@ requires-dist = [
{ name = "llama-stack-client", specifier = ">=0.2.14" },
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.14" },
{ name = "openai", specifier = ">=1.66" },
{ name = "opentelemetry-exporter-otlp-proto-http" },
{ name = "opentelemetry-sdk" },
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },
{ name = "pandas", marker = "extra == 'ui'" },
{ name = "pillow" },
{ name = "prompt-toolkit" },