From b3addc94d1f809a37dd09f186c04d423834a2fb0 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Fri, 17 Oct 2025 16:24:15 -0400 Subject: [PATCH] adding back relevant vector_db files Signed-off-by: Francisco Javier Arceo fix tests Signed-off-by: Francisco Javier Arceo updating tests and fixing routing logic for single provider Signed-off-by: Francisco Javier Arceo setting default provider to update tests Signed-off-by: Francisco Javier Arceo updated provider_id Signed-off-by: Francisco Javier Arceo updated VectorStoresConfig to use (provider_id, embedding_model_id) and add default vector store provider Signed-off-by: Francisco Javier Arceo special handling for replay mode for available providers Signed-off-by: Francisco Javier Arceo --- .../workflows/integration-vector-io-tests.yml | 4 +- llama_stack/apis/datatypes.py | 1 + llama_stack/apis/vector_dbs/vector_dbs.py | 34 +- llama_stack/cli/stack/_build.py | 103 +----- llama_stack/cli/stack/build.py | 7 - llama_stack/core/datatypes.py | 6 +- llama_stack/core/distribution.py | 4 + llama_stack/core/resolver.py | 2 + llama_stack/core/routers/__init__.py | 2 + llama_stack/core/routers/vector_io.py | 33 +- llama_stack/core/routing_tables/common.py | 3 + llama_stack/core/routing_tables/vector_dbs.py | 323 ++++++++++++++++++ llama_stack/core/stack.py | 2 +- llama_stack/distributions/ci-tests/run.yaml | 2 +- .../distributions/starter-gpu/run.yaml | 2 +- llama_stack/distributions/starter/run.yaml | 2 +- llama_stack/distributions/starter/starter.py | 2 +- tests/integration/conftest.py | 69 ++++ tests/integration/fixtures/common.py | 2 +- .../vector_io/test_openai_vector_stores.py | 146 ++++++-- tests/integration/vector_io/test_vector_io.py | 57 +++- tests/unit/core/test_stack_validation.py | 4 +- .../test_single_provider_filter.py | 88 ----- 23 files changed, 637 insertions(+), 261 deletions(-) create mode 100644 llama_stack/core/routing_tables/vector_dbs.py delete mode 100644 tests/unit/distribution/test_single_provider_filter.py diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index fed54cadb..89dc64a45 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -144,7 +144,7 @@ jobs: - name: Build Llama Stack run: | - uv run --no-sync llama stack build --distro starter --image-type venv --single-provider "vector_io=${{ matrix.vector-io-provider }}" + uv run --no-sync llama stack build --template ci-tests --image-type venv - name: Check Storage and Memory Available Before Tests if: ${{ always() }} @@ -168,7 +168,7 @@ WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }} run: | uv run --no-sync \ - pytest -sv --stack-config ~/.llama/distributions/starter/starter-filtered-run.yaml \ + pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ tests/integration/vector_io - name: Check Storage and Memory Available After Tests diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index 8fbf21f3e..5777f3d04 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -121,6 +121,7 @@ class Api(Enum, metaclass=DynamicApiMeta): models = "models" shields = "shields" + vector_dbs = "vector_dbs" # only used for routing datasets = "datasets" scoring_functions = "scoring_functions" benchmarks = "benchmarks" diff --git 
a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index 53bf181e9..0368095cb 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Literal +from typing import Literal, Protocol, runtime_checkable from pydantic import BaseModel @@ -59,3 +59,35 @@ class ListVectorDBsResponse(BaseModel): """ data: list[VectorDB] + + +@runtime_checkable +class VectorDBs(Protocol): + """Internal protocol for vector_dbs routing - no public API endpoints.""" + + async def list_vector_dbs(self) -> ListVectorDBsResponse: + """Internal method to list vector databases.""" + ... + + async def get_vector_db( + self, + vector_db_id: str, + ) -> VectorDB: + """Internal method to get a vector database by ID.""" + ... + + async def register_vector_db( + self, + vector_db_id: str, + embedding_model: str, + embedding_dimension: int | None = 384, + provider_id: str | None = None, + vector_db_name: str | None = None, + provider_vector_db_id: str | None = None, + ) -> VectorDB: + """Internal method to register a vector database.""" + ... + + async def unregister_vector_db(self, vector_db_id: str) -> None: + """Internal method to unregister a vector database.""" + ... diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index e90ec2ef6..471d5cb66 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -50,85 +50,6 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions" -def _apply_single_provider_filter(build_config: BuildConfig, single_provider_arg: str) -> BuildConfig: - """Filter a distribution to only include specified providers for certain APIs.""" - # Parse the single-provider argument using the same logic as --providers - provider_filters: dict[str, str] = {} - for api_provider in single_provider_arg.split(","): - if "=" not in api_provider: - cprint( - "Could not parse `--single-provider`. 
Please ensure the list is in the format api1=provider1,api2=provider2", - color="red", - file=sys.stderr, - ) - sys.exit(1) - api, provider_type = api_provider.split("=") - provider_filters[api] = provider_type - - # Create a copy of the build config to modify - filtered_build_config = BuildConfig( - image_type=build_config.image_type, - image_name=build_config.image_name, - external_providers_dir=build_config.external_providers_dir, - external_apis_dir=build_config.external_apis_dir, - distribution_spec=DistributionSpec( - providers={}, - description=build_config.distribution_spec.description, - ), - ) - - # Copy all providers, but filter the specified APIs - for api, providers in build_config.distribution_spec.providers.items(): - if api in provider_filters: - target_provider_type = provider_filters[api] - filtered_providers = [p for p in providers if p.provider_type == target_provider_type] - if not filtered_providers: - cprint( - f"Provider {target_provider_type} not found in distribution for API {api}", - color="red", - file=sys.stderr, - ) - sys.exit(1) - filtered_build_config.distribution_spec.providers[api] = filtered_providers - else: - # Keep all providers for unfiltered APIs - filtered_build_config.distribution_spec.providers[api] = providers - - return filtered_build_config - - -def _generate_filtered_run_config( - build_config: BuildConfig, - build_dir: Path, - distro_name: str, -) -> Path: - """ - Generate a filtered run.yaml by starting with the original distribution's run.yaml - and filtering the providers according to the build_config. - """ - # Load the original distribution's run.yaml - distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml" - - with importlib.resources.as_file(distro_resource) as path: - with open(path) as f: - original_config = yaml.safe_load(f) - - # Apply provider filtering to the loaded config - for api, providers in build_config.distribution_spec.providers.items(): - if api in original_config.get("providers", {}): - # Filter this API to only include the providers from build_config - provider_types = {p.provider_type for p in providers} - filtered_providers = [p for p in original_config["providers"][api] if p["provider_type"] in provider_types] - original_config["providers"][api] = filtered_providers - - # Write the filtered config - run_config_file = build_dir / f"{distro_name}-filtered-run.yaml" - with open(run_config_file, "w") as f: - yaml.dump(original_config, f, sort_keys=False) - - return run_config_file - - @lru_cache def available_distros_specs() -> dict[str, BuildConfig]: import yaml @@ -172,11 +93,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) sys.exit(1) build_config = available_distros[distro_name] - - # Apply single-provider filtering if specified - if args.single_provider: - build_config = _apply_single_provider_filter(build_config, args.single_provider) - if args.image_type: build_config.image_type = args.image_type else: @@ -329,7 +245,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None: image_name=image_name, config_path=args.config, distro_name=distro_name, - is_filtered=bool(args.single_provider), ) except (Exception, RuntimeError) as exc: @@ -448,7 +363,6 @@ def _run_stack_build_command_from_build_config( image_name: str | None = None, distro_name: str | None = None, config_path: str | None = None, - is_filtered: bool = False, ) -> Path | Traversable: image_name = image_name or build_config.image_name if build_config.image_type == 
LlamaStackImageType.CONTAINER.value: @@ -521,19 +435,12 @@ def _run_stack_build_command_from_build_config( raise RuntimeError(f"Failed to build image {image_name}") if distro_name: - # If single-provider filtering was applied, generate a filtered run config - # Otherwise, copy run.yaml from distribution as before - if is_filtered: - run_config_file = _generate_filtered_run_config(build_config, build_dir, distro_name) - distro_path = run_config_file # Use the generated file as the distro_path - else: - # copy run.yaml from distribution to build_dir instead of generating it again - distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml" - run_config_file = build_dir / f"{distro_name}-run.yaml" + # copy run.yaml from distribution to build_dir instead of generating it again + distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml" + run_config_file = build_dir / f"{distro_name}-run.yaml" - with importlib.resources.as_file(distro_resource) as path: - shutil.copy(path, run_config_file) - distro_path = run_config_file # Update distro_path to point to the copied file + with importlib.resources.as_file(distro_path) as path: + shutil.copy(path, run_config_file) cprint("Build Successful!", color="green", file=sys.stderr) cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr) diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index a6c466c6f..80cf6fb38 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -92,13 +92,6 @@ the build. If not specified, currently active environment will be used if found. help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.", ) - self.parser.add_argument( - "--single-provider", - type=str, - default=None, - help="Limit a distribution to a single provider for specific APIs. Format: api1=provider1,api2=provider2. Use with --distro to filter an existing distribution.", - ) - def _run_stack_build_command(self, args: argparse.Namespace) -> None: # always keep implementation completely silo-ed away from CLI so CLI # can be fast to load and reduces dependencies diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index d1e782510..49035114c 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -354,10 +354,14 @@ class AuthenticationRequiredError(Exception): class VectorStoresConfig(BaseModel): """Configuration for vector stores in the stack.""" - default_embedding_model_id: str = Field( + embedding_model_id: str = Field( ..., description="ID of the embedding model to use as default for vector stores when none is specified. 
Must reference a model defined in the 'models' section.", ) + provider_id: str | None = Field( + default=None, + description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.", + ) class QuotaPeriod(StrEnum): diff --git a/llama_stack/core/distribution.py b/llama_stack/core/distribution.py index 0e1f672c3..59461f5d6 100644 --- a/llama_stack/core/distribution.py +++ b/llama_stack/core/distribution.py @@ -63,6 +63,10 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]: routing_table_api=Api.tool_groups, router_api=Api.tool_runtime, ), + AutoRoutedApiInfo( + routing_table_api=Api.vector_dbs, + router_api=Api.vector_io, + ), ] diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py index 73c047979..ba6d1d8f9 100644 --- a/llama_stack/core/resolver.py +++ b/llama_stack/core/resolver.py @@ -28,6 +28,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA from llama_stack.core.client import get_client_impl @@ -80,6 +81,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> Api.inspect: Inspect, Api.batches: Batches, Api.vector_io: VectorIO, + Api.vector_dbs: VectorDBs, Api.models: Models, Api.safety: Safety, Api.shields: Shields, diff --git a/llama_stack/core/routers/__init__.py b/llama_stack/core/routers/__init__.py index 6f9a4a9e0..a264a2ee5 100644 --- a/llama_stack/core/routers/__init__.py +++ b/llama_stack/core/routers/__init__.py @@ -26,6 +26,7 @@ async def get_routing_table_impl( from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable from ..routing_tables.shields import ShieldsRoutingTable from ..routing_tables.toolgroups import ToolGroupsRoutingTable + from ..routing_tables.vector_dbs import VectorDBsRoutingTable api_to_tables = { "models": ModelsRoutingTable, @@ -34,6 +35,7 @@ async def get_routing_table_impl( "scoring_functions": ScoringFunctionsRoutingTable, "benchmarks": BenchmarksRoutingTable, "tool_groups": ToolGroupsRoutingTable, + "vector_dbs": VectorDBsRoutingTable, } if api.value not in api_to_tables: diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index d559d16e0..fd9ec387e 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -31,7 +31,6 @@ from llama_stack.apis.vector_io import ( VectorStoreObject, VectorStoreSearchResponsePage, ) -from llama_stack.core.datatypes import VectorStoresConfig from llama_stack.log import get_logger from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable @@ -44,7 +43,7 @@ class VectorIORouter(VectorIO): def __init__( self, routing_table: RoutingTable, - vector_stores_config: VectorStoresConfig | None = None, + vector_stores_config=None, ) -> None: logger.debug("Initializing VectorIORouter") self.routing_table = routing_table @@ -125,9 +124,9 @@ class VectorIORouter(VectorIO): embedding_dimension = extra.get("embedding_dimension") provider_id = extra.get("provider_id") + # Use default embedding model if not specified if embedding_model is None and self.vector_stores_config is not None: - embedding_model = self.vector_stores_config.default_embedding_model_id - 
logger.debug(f"Using default embedding model: {embedding_model}") + embedding_model = self.vector_stores_config.embedding_model_id if embedding_model is not None and embedding_dimension is None: embedding_dimension = await self._get_embedding_model_dimension(embedding_model) @@ -139,11 +138,24 @@ class VectorIORouter(VectorIO): raise ValueError("No vector_io providers available") if num_providers > 1: available_providers = list(self.routing_table.impls_by_provider_id.keys()) - raise ValueError( - f"Multiple vector_io providers available. Please specify provider_id in extra_body. " - f"Available providers: {available_providers}" - ) - provider_id = list(self.routing_table.impls_by_provider_id.keys())[0] + # Use default configured provider + if self.vector_stores_config and self.vector_stores_config.provider_id: + default_provider = self.vector_stores_config.provider_id + if default_provider in available_providers: + provider_id = default_provider + logger.debug(f"Using configured default vector store provider: {provider_id}") + else: + raise ValueError( + f"Configured default vector store provider '{default_provider}' not found. " + f"Available providers: {available_providers}" + ) + else: + raise ValueError( + f"Multiple vector_io providers available. Please specify provider_id in extra_body. " + f"Available providers: {available_providers}" + ) + else: + provider_id = list(self.routing_table.impls_by_provider_id.keys())[0] vector_db_id = f"vs_{uuid.uuid4()}" registered_vector_db = await self.routing_table.register_vector_db( @@ -250,8 +262,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, ) -> VectorStoreDeleteResponse: logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}") - provider = await self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_delete_vector_store(vector_store_id) + return await self.routing_table.openai_delete_vector_store(vector_store_id) async def openai_search_vector_store( self, diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index 8df0a89a9..087483bb6 100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -134,12 +134,15 @@ class CommonRoutingTableImpl(RoutingTable): from .scoring_functions import ScoringFunctionsRoutingTable from .shields import ShieldsRoutingTable from .toolgroups import ToolGroupsRoutingTable + from .vector_dbs import VectorDBsRoutingTable def apiname_object(): if isinstance(self, ModelsRoutingTable): return ("Inference", "model") elif isinstance(self, ShieldsRoutingTable): return ("Safety", "shield") + elif isinstance(self, VectorDBsRoutingTable): + return ("VectorIO", "vector_db") elif isinstance(self, DatasetsRoutingTable): return ("DatasetIO", "dataset") elif isinstance(self, ScoringFunctionsRoutingTable): diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py new file mode 100644 index 000000000..e87fb61c6 --- /dev/null +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from pydantic import TypeAdapter + +from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError +from llama_stack.apis.models import ModelType +from llama_stack.apis.resource import ResourceType + +# Removed VectorDBs import to avoid exposing public API +from llama_stack.apis.vector_io.vector_io import ( + OpenAICreateVectorStoreRequestWithExtraBody, + SearchRankingOptions, + VectorStoreChunkingStrategy, + VectorStoreDeleteResponse, + VectorStoreFileContentsResponse, + VectorStoreFileDeleteResponse, + VectorStoreFileObject, + VectorStoreFileStatus, + VectorStoreObject, + VectorStoreSearchResponsePage, +) +from llama_stack.core.datatypes import ( + VectorDBWithOwner, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl, lookup_model + +logger = get_logger(name=__name__, category="core::routing_tables") + + +class VectorDBsRoutingTable(CommonRoutingTableImpl): + """Internal routing table for vector_db operations. + + Does not inherit from VectorDBs to avoid exposing public API endpoints. + Only provides internal routing functionality for VectorIORouter. + """ + + # Internal methods only - no public API exposure + + async def register_vector_db( + self, + vector_db_id: str, + embedding_model: str, + embedding_dimension: int | None = 384, + provider_id: str | None = None, + provider_vector_db_id: str | None = None, + vector_db_name: str | None = None, + ) -> Any: + if provider_id is None: + if len(self.impls_by_provider_id) > 0: + provider_id = list(self.impls_by_provider_id.keys())[0] + if len(self.impls_by_provider_id) > 1: + logger.warning( + f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." + ) + else: + raise ValueError("No provider available. Please configure a vector_io provider.") + model = await lookup_model(self, embedding_model) + if model is None: + raise ModelNotFoundError(embedding_model) + if model.model_type != ModelType.embedding: + raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding) + if "embedding_dimension" not in model.metadata: + raise ValueError(f"Model {embedding_model} does not have an embedding dimension") + + try: + provider = self.impls_by_provider_id[provider_id] + except KeyError: + available_providers = list(self.impls_by_provider_id.keys()) + raise ValueError( + f"Provider '{provider_id}' not found in routing table. Available providers: {available_providers}" + ) from None + logger.warning( + "VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly." + ) + request = OpenAICreateVectorStoreRequestWithExtraBody( + name=vector_db_name or vector_db_id, + embedding_model=embedding_model, + embedding_dimension=model.metadata["embedding_dimension"], + provider_id=provider_id, + provider_vector_db_id=provider_vector_db_id, + ) + vector_store = await provider.openai_create_vector_store(request) + + vector_store_id = vector_store.id + actual_provider_vector_db_id = provider_vector_db_id or vector_store_id + logger.warning( + f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. 
The requested ID is retained as the vector store's name when no explicit name is given." + ) + + vector_db_data = { + "identifier": vector_store_id, + "type": ResourceType.vector_db.value, + "provider_id": provider_id, + "provider_resource_id": actual_provider_vector_db_id, + "embedding_model": embedding_model, + "embedding_dimension": model.metadata["embedding_dimension"], + "vector_db_name": vector_store.name, + } + vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data) + await self.register_object(vector_db) + return vector_db + + async def openai_retrieve_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreObject: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store(vector_store_id) + + async def openai_update_vector_store( + self, + vector_store_id: str, + name: str | None = None, + expires_after: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ) -> VectorStoreObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_update_vector_store( + vector_store_id=vector_store_id, + name=name, + expires_after=expires_after, + metadata=metadata, + ) + + async def openai_delete_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreDeleteResponse: + await self.assert_action_allowed("delete", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + result = await provider.openai_delete_vector_store(vector_store_id) + await self.unregister_vector_db(vector_store_id) + return result + + async def unregister_vector_db(self, vector_store_id: str) -> None: + """Remove the vector store from the routing table registry.""" + try: + vector_db_obj = await self.get_object_by_identifier("vector_db", vector_store_id) + if vector_db_obj: + await self.unregister_object(vector_db_obj) + except Exception as e: + # Log the error but don't fail the operation + logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}") + + async def openai_search_vector_store( + self, + vector_store_id: str, + query: str | list[str], + filters: dict[str, Any] | None = None, + max_num_results: int | None = 10, + ranking_options: SearchRankingOptions | None = None, + rewrite_query: bool | None = False, + search_mode: str | None = "vector", + ) -> VectorStoreSearchResponsePage: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_search_vector_store( + vector_store_id=vector_store_id, + query=query, + filters=filters, + max_num_results=max_num_results, + ranking_options=ranking_options, + rewrite_query=rewrite_query, + search_mode=search_mode, + ) + + async def openai_attach_file_to_vector_store( + self, + vector_store_id: str, + file_id: str, + attributes: dict[str, Any] | None = None, + chunking_strategy: VectorStoreChunkingStrategy | None = None, + ) -> VectorStoreFileObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_attach_file_to_vector_store( + vector_store_id=vector_store_id, + file_id=file_id, + attributes=attributes, + chunking_strategy=chunking_strategy, + ) + + async def openai_list_files_in_vector_store( + self, + vector_store_id: str, + limit: 
int | None = 20, + order: str | None = "desc", + after: str | None = None, + before: str | None = None, + filter: VectorStoreFileStatus | None = None, + ) -> list[VectorStoreFileObject]: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_list_files_in_vector_store( + vector_store_id=vector_store_id, + limit=limit, + order=order, + after=after, + before=before, + filter=filter, + ) + + async def openai_retrieve_vector_store_file( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileObject: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + async def openai_retrieve_vector_store_file_contents( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileContentsResponse: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store_file_contents( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + async def openai_update_vector_store_file( + self, + vector_store_id: str, + file_id: str, + attributes: dict[str, Any], + ) -> VectorStoreFileObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_update_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + attributes=attributes, + ) + + async def openai_delete_vector_store_file( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileDeleteResponse: + await self.assert_action_allowed("delete", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_delete_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + async def openai_create_vector_store_file_batch( + self, + vector_store_id: str, + file_ids: list[str], + attributes: dict[str, Any] | None = None, + chunking_strategy: Any | None = None, + ): + await self.assert_action_allowed("update", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_create_vector_store_file_batch( + vector_store_id=vector_store_id, + file_ids=file_ids, + attributes=attributes, + chunking_strategy=chunking_strategy, + ) + + async def openai_retrieve_vector_store_file_batch( + self, + batch_id: str, + vector_store_id: str, + ): + await self.assert_action_allowed("read", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store_file_batch( + batch_id=batch_id, + vector_store_id=vector_store_id, + ) + + async def openai_list_files_in_vector_store_file_batch( + self, + batch_id: str, + vector_store_id: str, + after: str | None = None, + before: str | None = None, + filter: str | None = None, + limit: int | None = 20, + order: str | None = "desc", + ): + await self.assert_action_allowed("read", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_list_files_in_vector_store_file_batch( + batch_id=batch_id, + vector_store_id=vector_store_id, + after=after, + before=before, + filter=filter, + limit=limit, + 
order=order, + ) + + async def openai_cancel_vector_store_file_batch( + self, + batch_id: str, + vector_store_id: str, + ): + await self.assert_action_allowed("update", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_cancel_vector_store_file_batch( + batch_id=batch_id, + vector_store_id=vector_store_id, + ) diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index 5050102d8..6a1015881 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -135,7 +135,7 @@ async def validate_vector_stores_config(run_config: StackRunConfig, impls: dict[ return vector_stores_config = run_config.vector_stores - default_model_id = vector_stores_config.default_embedding_model_id + default_model_id = vector_stores_config.embedding_model_id if Api.models not in impls: raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'") diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index 59decf197..3ae049cbc 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -255,4 +255,4 @@ server: telemetry: enabled: true vector_stores: - default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5 + embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5 diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml index d98b26b8a..6803e8a64 100644 --- a/llama_stack/distributions/starter-gpu/run.yaml +++ b/llama_stack/distributions/starter-gpu/run.yaml @@ -258,4 +258,4 @@ server: telemetry: enabled: true vector_stores: - default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5 + embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5 diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index f47a8e1b2..ca18baf09 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -255,4 +255,4 @@ server: telemetry: enabled: true vector_stores: - default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5 + embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5 diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index da775f566..897f75bf9 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -249,7 +249,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate: default_tool_groups=default_tool_groups, default_shields=default_shields, vector_stores_config=VectorStoresConfig( - default_embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5" + embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5" ), ), }, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 3137de0de..a258eb1a0 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -317,3 +317,72 @@ def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: if p.is_relative_to(rp): return False return True + + +def get_vector_io_provider_ids(client): + """Get all available vector_io provider IDs.""" + providers = [p for p in client.providers.list() if p.api == "vector_io"] + return [p.provider_id for p in providers] + + +def 
vector_provider_wrapper(func): + """Decorator to run a test against all available vector_io providers.""" + import functools + import os + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Get the vector_io_provider_id from the test arguments + import inspect + + sig = inspect.signature(func) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + vector_io_provider_id = bound_args.arguments.get("vector_io_provider_id") + if not vector_io_provider_id: + pytest.skip("No vector_io_provider_id provided") + + # Get client_with_models to check available providers + client_with_models = bound_args.arguments.get("client_with_models") + if client_with_models: + available_providers = get_vector_io_provider_ids(client_with_models) + if vector_io_provider_id not in available_providers: + pytest.skip(f"Provider '{vector_io_provider_id}' not available. Available: {available_providers}") + + return func(*args, **kwargs) + + # For replay tests, only use providers that are available in ci-tests environment + if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE") == "replay": + all_providers = ["faiss", "sqlite-vec"] + else: + # For live tests, try all providers (they'll skip if not available) + all_providers = [ + "faiss", + "sqlite-vec", + "milvus", + "chromadb", + "pgvector", + "weaviate", + "qdrant", + ] + + return pytest.mark.parametrize("vector_io_provider_id", all_providers)(wrapper) + + +@pytest.fixture +def vector_io_provider_id(request, client_with_models): + """Fixture that provides a specific vector_io provider ID, skipping if not available.""" + if hasattr(request, "param"): + requested_provider = request.param + available_providers = get_vector_io_provider_ids(client_with_models) + + if requested_provider not in available_providers: + pytest.skip(f"Provider '{requested_provider}' not available. 
Available: {available_providers}") + + return requested_provider + else: + provider_ids = get_vector_io_provider_ids(client_with_models) + if not provider_ids: + pytest.skip("No vector_io providers available") + return provider_ids[0] diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index fd034fdc2..0d9dc1970 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -241,7 +241,7 @@ def instantiate_llama_stack_client(session): # --stack-config bypasses template so need this to set default embedding model if "vector_io" in config and "inference" in config: run_config.vector_stores = VectorStoresConfig( - default_embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5" + embedding_model_id="inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5" ) run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index e21b233bc..9b28adc90 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -16,6 +16,8 @@ from llama_stack.apis.vector_io import Chunk from llama_stack.core.library_client import LlamaStackAsLibraryClient from llama_stack.log import get_logger +from ..conftest import vector_provider_wrapper + logger = get_logger(name=__name__, category="vector_io") @@ -133,8 +135,9 @@ def compat_client_with_empty_stores(compat_client): clear_files() +@vector_provider_wrapper def test_openai_create_vector_store( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test creating a vector store using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -146,6 +149,7 @@ def test_openai_create_vector_store( metadata={"purpose": "testing", "environment": "integration"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -159,14 +163,18 @@ def test_openai_create_vector_store( assert hasattr(vector_store, "created_at") -def test_openai_create_vector_store_default(compat_client_with_empty_stores, client_with_models): +@vector_provider_wrapper +def test_openai_create_vector_store_default(compat_client_with_empty_stores, client_with_models, vector_io_provider_id): skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - vector_store = compat_client_with_empty_stores.vector_stores.create() + vector_store = compat_client_with_empty_stores.vector_stores.create( + extra_body={"provider_id": vector_io_provider_id} + ) assert vector_store.id +@vector_provider_wrapper def test_openai_list_vector_stores( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test listing vector stores using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -179,6 +187,7 @@ def test_openai_list_vector_stores( metadata={"type": "test"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) store2 = client.vector_stores.create( @@ -186,6 +195,7 @@ def test_openai_list_vector_stores( metadata={"type": 
"test"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -206,8 +216,9 @@ def test_openai_list_vector_stores( assert len(limited_response.data) == 1 +@vector_provider_wrapper def test_openai_retrieve_vector_store( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test retrieving a specific vector store using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -220,6 +231,7 @@ def test_openai_retrieve_vector_store( metadata={"purpose": "retrieval_test"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -233,8 +245,9 @@ def test_openai_retrieve_vector_store( assert retrieved_store.object == "vector_store" +@vector_provider_wrapper def test_openai_update_vector_store( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test modifying a vector store using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -247,6 +260,7 @@ def test_openai_update_vector_store( metadata={"version": "1.0"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) time.sleep(1) @@ -264,8 +278,9 @@ def test_openai_update_vector_store( assert modified_store.last_active_at > created_store.last_active_at +@vector_provider_wrapper def test_openai_delete_vector_store( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test deleting a vector store using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -278,6 +293,7 @@ def test_openai_delete_vector_store( metadata={"purpose": "deletion_test"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -294,8 +310,9 @@ def test_openai_delete_vector_store( client.vector_stores.retrieve(vector_store_id=created_store.id) +@vector_provider_wrapper def test_openai_vector_store_search_empty( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test searching an empty vector store using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -308,6 +325,7 @@ def test_openai_vector_store_search_empty( metadata={"purpose": "search_testing"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -323,8 +341,14 @@ def test_openai_vector_store_search_empty( assert search_response.has_more is False +@vector_provider_wrapper def test_openai_vector_store_with_chunks( - compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, + client_with_models, + sample_chunks, + embedding_model_id, + embedding_dimension, + vector_io_provider_id, ): """Test vector store functionality with actual chunks using both OpenAI and native APIs.""" 
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -338,6 +362,7 @@ def test_openai_vector_store_with_chunks( metadata={"purpose": "chunks_testing"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -380,6 +405,7 @@ def test_openai_vector_store_with_chunks( ("What inspires neural networks?", "doc4", "ai"), ], ) +@vector_provider_wrapper def test_openai_vector_store_search_relevance( compat_client_with_empty_stores, client_with_models, @@ -387,6 +413,7 @@ def test_openai_vector_store_search_relevance( test_case, embedding_model_id, embedding_dimension, + vector_io_provider_id, ): """Test that OpenAI vector store search returns relevant results for different queries.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -402,6 +429,7 @@ def test_openai_vector_store_search_relevance( metadata={"purpose": "relevance_testing"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -430,8 +458,14 @@ def test_openai_vector_store_search_relevance( assert top_result.score > 0 +@vector_provider_wrapper def test_openai_vector_store_search_with_ranking_options( - compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, + client_with_models, + sample_chunks, + embedding_model_id, + embedding_dimension, + vector_io_provider_id, ): """Test OpenAI vector store search with ranking options.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -445,6 +479,7 @@ def test_openai_vector_store_search_with_ranking_options( metadata={"purpose": "ranking_testing"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -483,8 +518,14 @@ def test_openai_vector_store_search_with_ranking_options( assert result.score >= threshold +@vector_provider_wrapper def test_openai_vector_store_search_with_high_score_filter( - compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, + client_with_models, + sample_chunks, + embedding_model_id, + embedding_dimension, + vector_io_provider_id, ): """Test that searching with text very similar to a document and high score threshold returns only that document.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -498,6 +539,7 @@ def test_openai_vector_store_search_with_high_score_filter( metadata={"purpose": "high_score_filtering"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -542,8 +584,14 @@ def test_openai_vector_store_search_with_high_score_filter( assert "python" in top_content.lower() or "programming" in top_content.lower() +@vector_provider_wrapper def test_openai_vector_store_search_with_max_num_results( - compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, + client_with_models, + sample_chunks, + embedding_model_id, + embedding_dimension, + vector_io_provider_id, ): """Test OpenAI vector store search with max_num_results.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -557,6 +605,7 @@ def test_openai_vector_store_search_with_max_num_results( metadata={"purpose": "max_num_results_testing"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -577,8 
+626,9 @@ def test_openai_vector_store_search_with_max_num_results( assert len(search_response.data) == 2 +@vector_provider_wrapper def test_openai_vector_store_attach_file( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test OpenAI vector store attach file.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -591,6 +641,7 @@ def test_openai_vector_store_attach_file( name="test_store", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -637,8 +688,9 @@ def test_openai_vector_store_attach_file( assert "foobazbar" in top_content.lower() +@vector_provider_wrapper def test_openai_vector_store_attach_files_on_creation( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test OpenAI vector store attach files on creation.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -668,6 +720,7 @@ def test_openai_vector_store_attach_files_on_creation( file_ids=file_ids, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -700,8 +753,9 @@ def test_openai_vector_store_attach_files_on_creation( assert updated_vector_store.file_counts.failed == 0 +@vector_provider_wrapper def test_openai_vector_store_list_files( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test OpenAI vector store list files.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -714,6 +768,7 @@ def test_openai_vector_store_list_files( name="test_store", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -773,8 +828,9 @@ def test_openai_vector_store_list_files( assert updated_vector_store.file_counts.in_progress == 0 +@vector_provider_wrapper def test_openai_vector_store_list_files_invalid_vector_store( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test OpenAI vector store list files with invalid vector store ID.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -789,8 +845,9 @@ def test_openai_vector_store_list_files_invalid_vector_store( compat_client.vector_stores.files.list(vector_store_id="abc123") +@vector_provider_wrapper def test_openai_vector_store_retrieve_file_contents( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test OpenAI vector store retrieve file contents.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -803,6 +860,7 @@ def test_openai_vector_store_retrieve_file_contents( name="test_store", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -848,8 +906,9 @@ def test_openai_vector_store_retrieve_file_contents( assert file_contents.attributes == attributes 
+@vector_provider_wrapper def test_openai_vector_store_delete_file( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test OpenAI vector store delete file.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -862,6 +921,7 @@ def test_openai_vector_store_delete_file( name="test_store", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -912,8 +972,9 @@ def test_openai_vector_store_delete_file( assert updated_vector_store.file_counts.in_progress == 0 +@vector_provider_wrapper def test_openai_vector_store_delete_file_removes_from_vector_store( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test OpenAI vector store delete file removes from vector store.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -926,6 +987,7 @@ def test_openai_vector_store_delete_file_removes_from_vector_store( name="test_store", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -962,8 +1024,9 @@ def test_openai_vector_store_delete_file_removes_from_vector_store( assert not search_response.data +@vector_provider_wrapper def test_openai_vector_store_update_file( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test OpenAI vector store update file.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -976,6 +1039,7 @@ def test_openai_vector_store_update_file( name="test_store", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -1017,8 +1081,9 @@ def test_openai_vector_store_update_file( assert retrieved_file.attributes["foo"] == "baz" +@vector_provider_wrapper def test_create_vector_store_files_duplicate_vector_store_name( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """ This test confirms that client.vector_stores.create() creates a unique ID @@ -1044,6 +1109,7 @@ def test_create_vector_store_files_duplicate_vector_store_name( name="test_store_with_files", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) assert vector_store.file_counts.completed == 0 @@ -1056,6 +1122,7 @@ def test_create_vector_store_files_duplicate_vector_store_name( name="test_store_with_files", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -1086,8 +1153,15 @@ def test_create_vector_store_files_duplicate_vector_store_name( @pytest.mark.parametrize("search_mode", ["vector", "keyword", "hybrid"]) +@vector_provider_wrapper def test_openai_vector_store_search_modes( - llama_stack_client, client_with_models, sample_chunks, search_mode, embedding_model_id, embedding_dimension + llama_stack_client, + client_with_models, + sample_chunks, + search_mode, + embedding_model_id, + embedding_dimension, + vector_io_provider_id, ): 
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_models, search_mode) @@ -1097,6 +1171,7 @@ def test_openai_vector_store_search_modes( metadata={"purpose": "search_mode_testing"}, extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -1115,8 +1190,9 @@ def test_openai_vector_store_search_modes( assert search_response is not None +@vector_provider_wrapper def test_openai_vector_store_file_batch_create_and_retrieve( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test creating and retrieving a vector store file batch.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -1128,6 +1204,7 @@ def test_openai_vector_store_file_batch_create_and_retrieve( name="batch_test_store", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -1178,8 +1255,9 @@ def test_openai_vector_store_file_batch_create_and_retrieve( assert retrieved_batch.status == "completed" # Should be completed after processing +@vector_provider_wrapper def test_openai_vector_store_file_batch_list_files( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test listing files in a vector store file batch.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -1191,6 +1269,7 @@ def test_openai_vector_store_file_batch_list_files( name="batch_list_test_store", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -1271,8 +1350,9 @@ def test_openai_vector_store_file_batch_list_files( assert first_page_ids.isdisjoint(second_page_ids) +@vector_provider_wrapper def test_openai_vector_store_file_batch_cancel( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test cancelling a vector store file batch.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -1284,6 +1364,7 @@ def test_openai_vector_store_file_batch_cancel( name="batch_cancel_test_store", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -1326,8 +1407,9 @@ def test_openai_vector_store_file_batch_cancel( assert final_batch.status in ["completed", "cancelled"] +@vector_provider_wrapper def test_openai_vector_store_file_batch_retrieve_contents( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test retrieving file contents after file batch processing.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -1339,6 +1421,7 @@ def test_openai_vector_store_file_batch_retrieve_contents( name="batch_contents_test_store", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -1399,8 +1482,9 @@ def test_openai_vector_store_file_batch_retrieve_contents( assert file_data[i][1].decode("utf-8") in content_text 
+@vector_provider_wrapper def test_openai_vector_store_file_batch_error_handling( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test error handling for file batch operations.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -1412,6 +1496,7 @@ def test_openai_vector_store_file_batch_error_handling( name="batch_error_test_store", extra_body={ "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -1456,8 +1541,9 @@ def test_openai_vector_store_file_batch_error_handling( ) +@vector_provider_wrapper def test_openai_vector_store_embedding_config_from_metadata( - compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension + compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id ): """Test that embedding configuration works from metadata source.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) @@ -1471,6 +1557,9 @@ def test_openai_vector_store_embedding_config_from_metadata( "embedding_dimension": str(embedding_dimension), "test_source": "metadata", }, + extra_body={ + "provider_id": vector_io_provider_id, + }, ) assert vector_store_metadata is not None @@ -1489,6 +1578,7 @@ def test_openai_vector_store_embedding_config_from_metadata( extra_body={ "embedding_model": embedding_model_id, "embedding_dimension": int(embedding_dimension), # Ensure same type/value + "provider_id": vector_io_provider_id, }, ) diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index 653299338..e5ca7a0db 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -8,6 +8,8 @@ import pytest from llama_stack.apis.vector_io import Chunk +from ..conftest import vector_provider_wrapper + @pytest.fixture(scope="session") def sample_chunks(): @@ -46,12 +48,13 @@ def client_with_empty_registry(client_with_models): clear_registry() -def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension): +@vector_provider_wrapper +def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id): vector_db_name = "test_vector_db" create_response = client_with_empty_registry.vector_stores.create( name=vector_db_name, extra_body={ - "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -65,12 +68,13 @@ def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embe assert response.id.startswith("vs_") -def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension): +@vector_provider_wrapper +def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id): vector_db_name = "test_vector_db" response = client_with_empty_registry.vector_stores.create( name=vector_db_name, extra_body={ - "embedding_model": embedding_model_id, + "provider_id": vector_io_provider_id, }, ) @@ -100,12 +104,15 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe ("How does machine learning improve over time?", "doc2"), ], ) -def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case): 
+@vector_provider_wrapper
+def test_insert_chunks(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case, vector_io_provider_id
+):
     vector_db_name = "test_vector_db"
     create_response = client_with_empty_registry.vector_stores.create(
         name=vector_db_name,
         extra_body={
-            "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -135,7 +142,10 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
     assert top_match.metadata["document_id"] == expected_doc_id, f"Query '{query}' should match {expected_doc_id}"


-def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension):
+@vector_provider_wrapper
+def test_insert_chunks_with_precomputed_embeddings(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
     vector_io_provider_params_dict = {
         "inline::milvus": {"score_threshold": -1.0},
         "inline::qdrant": {"score_threshold": -1.0},
@@ -145,7 +155,7 @@
     register_response = client_with_empty_registry.vector_stores.create(
         name=vector_db_name,
         extra_body={
-            "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -181,8 +191,9 @@
     # expect this test to fail


+@vector_provider_wrapper
 def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
-    client_with_empty_registry, embedding_model_id, embedding_dimension
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
 ):
     vector_io_provider_params_dict = {
         "inline::milvus": {"score_threshold": 0.0},
@@ -194,6 +205,7 @@
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
+            "provider_id": vector_io_provider_id,
         },
     )

@@ -226,33 +238,44 @@
     assert response.chunks[0].metadata["source"] == "precomputed"


-def test_auto_extract_embedding_dimension(client_with_empty_registry, embedding_model_id):
+@vector_provider_wrapper
+def test_auto_extract_embedding_dimension(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
+    # This test specifically tests embedding model override, so we keep embedding_model
     vs = client_with_empty_registry.vector_stores.create(
-        name="test_auto_extract", extra_body={"embedding_model": embedding_model_id}
+        name="test_auto_extract",
+        extra_body={"embedding_model": embedding_model_id, "provider_id": vector_io_provider_id},
     )
     assert vs.id is not None


-def test_provider_auto_selection_single_provider(client_with_empty_registry, embedding_model_id):
+@vector_provider_wrapper
+def test_provider_auto_selection_single_provider(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
     providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
     if len(providers) != 1:
         pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")

-    vs = client_with_empty_registry.vector_stores.create(
-        name="test_auto_provider", extra_body={"embedding_model": embedding_model_id}
-    )
+    # Test that when only one provider is available, it's auto-selected (no provider_id needed)
+    vs = client_with_empty_registry.vector_stores.create(name="test_auto_provider")
     assert vs.id is not None


-def test_provider_id_override(client_with_empty_registry, embedding_model_id):
+@vector_provider_wrapper
+def test_provider_id_override(
+    client_with_empty_registry, embedding_model_id, embedding_dimension, vector_io_provider_id
+):
     providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
     if len(providers) != 1:
         pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")

     provider_id = providers[0].provider_id
+    # Test explicit provider_id specification (using default embedding model)
     vs = client_with_empty_registry.vector_stores.create(
-        name="test_provider_override", extra_body={"embedding_model": embedding_model_id, "provider_id": provider_id}
+        name="test_provider_override", extra_body={"provider_id": provider_id}
     )
     assert vs.id is not None
     assert vs.metadata.get("provider_id") == provider_id
diff --git a/tests/unit/core/test_stack_validation.py b/tests/unit/core/test_stack_validation.py
index b5f6c1b24..b9fe29f23 100644
--- a/tests/unit/core/test_stack_validation.py
+++ b/tests/unit/core/test_stack_validation.py
@@ -20,7 +20,7 @@ class TestVectorStoresValidation:
     async def test_validate_missing_model(self):
         """Test validation fails when model not found."""
         run_config = StackRunConfig(
-            image_name="test", providers={}, vector_stores=VectorStoresConfig(default_embedding_model_id="missing")
+            image_name="test", providers={}, vector_stores=VectorStoresConfig(embedding_model_id="missing")
         )
         mock_models = AsyncMock()
         mock_models.list_models.return_value = []
@@ -31,7 +31,7 @@ class TestVectorStoresValidation:
     async def test_validate_success(self):
         """Test validation passes with valid model."""
         run_config = StackRunConfig(
-            image_name="test", providers={}, vector_stores=VectorStoresConfig(default_embedding_model_id="valid")
+            image_name="test", providers={}, vector_stores=VectorStoresConfig(embedding_model_id="valid")
         )
         mock_models = AsyncMock()
         mock_models.list_models.return_value = [
diff --git a/tests/unit/distribution/test_single_provider_filter.py b/tests/unit/distribution/test_single_provider_filter.py
deleted file mode 100644
index 7447e3eaa..000000000
--- a/tests/unit/distribution/test_single_provider_filter.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-
-from llama_stack.cli.stack._build import _apply_single_provider_filter
-from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
-from llama_stack.core.utils.image_types import LlamaStackImageType
-
-
-def test_filters_single_api():
-    """Test filtering keeps only specified provider for one API."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(
-            providers={
-                "vector_io": [
-                    BuildProvider(provider_type="inline::faiss"),
-                    BuildProvider(provider_type="inline::sqlite-vec"),
-                ],
-                "inference": [
-                    BuildProvider(provider_type="remote::openai"),
-                ],
-            },
-            description="Test",
-        ),
-    )
-
-    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::sqlite-vec")
-
-    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
-    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::sqlite-vec"
-    assert len(filtered.distribution_spec.providers["inference"]) == 1  # unchanged
-
-
-def test_filters_multiple_apis():
-    """Test filtering multiple APIs."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(
-            providers={
-                "vector_io": [
-                    BuildProvider(provider_type="inline::faiss"),
-                    BuildProvider(provider_type="inline::sqlite-vec"),
-                ],
-                "inference": [
-                    BuildProvider(provider_type="remote::openai"),
-                    BuildProvider(provider_type="remote::anthropic"),
-                ],
-            },
-            description="Test",
-        ),
-    )
-
-    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::faiss,inference=remote::openai")
-
-    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
-    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::faiss"
-    assert len(filtered.distribution_spec.providers["inference"]) == 1
-    assert filtered.distribution_spec.providers["inference"][0].provider_type == "remote::openai"
-
-
-def test_provider_not_found_exits():
-    """Test error when specified provider doesn't exist."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(
-            providers={"vector_io": [BuildProvider(provider_type="inline::faiss")]},
-            description="Test",
-        ),
-    )
-
-    with pytest.raises(SystemExit):
-        _apply_single_provider_filter(build_config, "vector_io=inline::nonexistent")
-
-
-def test_invalid_format_exits():
-    """Test error for invalid filter format."""
-    build_config = BuildConfig(
-        image_type=LlamaStackImageType.VENV.value,
-        distribution_spec=DistributionSpec(providers={}, description="Test"),
-    )
-
-    with pytest.raises(SystemExit):
-        _apply_single_provider_filter(build_config, "invalid_format")
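
The integration tests in this patch import vector_provider_wrapper from
tests/integration/conftest.py, whose body falls outside this excerpt. A
minimal sketch of what such a wrapper could look like, assuming it
parametrizes each test over the vector_io provider IDs enabled for the run;
the VECTOR_IO_PROVIDERS environment variable and the "faiss" default below
are illustrative assumptions, not the actual implementation:

    import os

    import pytest


    def vector_provider_wrapper(test_func):
        """Parametrize a test over the vector_io provider IDs enabled for this run."""
        # Hypothetical discovery mechanism; the real conftest.py may instead
        # inspect the providers configured for the running stack.
        provider_ids = os.environ.get("VECTOR_IO_PROVIDERS", "faiss").split(",")
        # pytest.mark.parametrize supplies the vector_io_provider_id argument
        # that every wrapped test declares.
        return pytest.mark.parametrize("vector_io_provider_id", provider_ids)(test_func)

Each wrapped test then receives one provider ID per parametrized run and
forwards it to vector_stores.create() via extra_body, so the request is routed
to that specific vector_io provider.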