From accc4c437ed22baaffa32111fe359f72c93e9178 Mon Sep 17 00:00:00 2001
From: Francisco Javier Arceo
Date: Thu, 16 Oct 2025 10:59:01 -0400
Subject: [PATCH] Update resolver to only pass vector_stores section of run
 config

Signed-off-by: Francisco Javier Arceo

Use Router only for VectorDBs

Signed-off-by: Francisco Javier Arceo

Remove models_api from vector store providers

Signed-off-by: Francisco Javier Arceo

Fix test

Signed-off-by: Francisco Javier Arceo

Update integration tests

Signed-off-by: Francisco Javier Arceo

Add special handling for replay mode for available providers

Signed-off-by: Francisco Javier Arceo
---
 .github/workflows/integration-vector-io-tests.yml |   4 +-
 llama_stack/cli/stack/_build.py                    | 103 +++++++++++++++++-
 llama_stack/cli/stack/build.py                     |   7 ++
 llama_stack/core/resolver.py                       |   4 -
 llama_stack/core/routers/__init__.py               |   3 +
 llama_stack/core/routers/vector_io.py              |   7 ++
 llama_stack/distributions/ci-tests/build.yaml      |   2 +
 llama_stack/distributions/ci-tests/run.yaml        |  15 +++
 .../distributions/starter-gpu/build.yaml           |   2 +
 .../distributions/starter-gpu/run.yaml             |  19 +++-
 llama_stack/distributions/starter/build.yaml       |   2 +
 llama_stack/distributions/starter/run.yaml         |  15 +++
 llama_stack/distributions/starter/starter.py       |  20 ++++
 .../inline/vector_io/chroma/__init__.py            |  21 +---
 .../inline/vector_io/chroma/config.py              |   3 +-
 .../inline/vector_io/faiss/__init__.py             |  17 +--
 .../inline/vector_io/faiss/config.py               |  10 +-
 .../providers/inline/vector_io/faiss/faiss.py      |  67 ++----------
 .../inline/vector_io/milvus/__init__.py            |  17 +--
 .../inline/vector_io/milvus/config.py              |   8 +-
 .../inline/vector_io/qdrant/__init__.py            |  17 +--
 .../inline/vector_io/qdrant/config.py              |   5 +-
 .../inline/vector_io/sqlite_vec/__init__.py        |  17 +--
 .../inline/vector_io/sqlite_vec/config.py          |   8 +-
 .../inline/vector_io/sqlite_vec/sqlite_vec.py      |  65 ++---------
 .../remote/vector_io/chroma/__init__.py            |  17 +--
 .../remote/vector_io/chroma/chroma.py              |  55 ++--------
 .../remote/vector_io/chroma/config.py              |   3 +-
 .../remote/vector_io/milvus/__init__.py            |  17 +--
 .../remote/vector_io/milvus/config.py              |   3 +-
 .../remote/vector_io/milvus/milvus.py              |  84 +++-----------
 .../remote/vector_io/pgvector/__init__.py          |  16 +--
 .../remote/vector_io/pgvector/config.py            |   8 +-
 .../remote/vector_io/pgvector/pgvector.py          |  54 ++-------
 .../remote/vector_io/qdrant/__init__.py            |  17 +--
 .../remote/vector_io/qdrant/config.py              |   8 +-
 .../remote/vector_io/qdrant/qdrant.py              |  49 ++-------
 .../remote/vector_io/weaviate/__init__.py          |  17 +--
 .../remote/vector_io/weaviate/config.py            |  14 +--
 .../remote/vector_io/weaviate/weaviate.py          |  87 +++------------
 .../utils/memory/openai_vector_store_mixin.py      |  79 +------------
 tests/integration/fixtures/common.py               |   2 +-
 .../test_single_provider_filter.py                 |  88 +++++++++++++++
 tests/unit/providers/vector_io/conftest.py         |   2 -
 tests/unit/providers/vector_io/test_faiss.py       |  19 +---
 .../test_vector_io_openai_vector_stores.py         |   2 +-
 46 files changed, 397 insertions(+), 702 deletions(-)
 create mode 100644 tests/unit/distribution/test_single_provider_filter.py

diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml
index 89dc64a45..fed54cadb 100644
--- a/.github/workflows/integration-vector-io-tests.yml
+++ b/.github/workflows/integration-vector-io-tests.yml
@@ -144,7 +144,7 @@ jobs:
       - name: Build Llama Stack
         run: |
-          uv run --no-sync llama stack build --template ci-tests --image-type venv
+          uv run --no-sync llama stack build --distro starter --image-type venv --single-provider "vector_io=${{ matrix.vector-io-provider }}"
 
      - name: Check Storage and Memory Available Before Tests
        if: ${{ always() }}
@@ -168,7 +168,7 @@ jobs:
          WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
        run: |
          uv run --no-sync \
-            pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
+            pytest -sv --stack-config ~/.llama/distributions/starter/starter-filtered-run.yaml \
            tests/integration/vector_io
 
      - name: Check Storage and Memory Available After Tests
diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py
index 471d5cb66..e90ec2ef6 100644
--- a/llama_stack/cli/stack/_build.py
+++ b/llama_stack/cli/stack/_build.py
@@ -50,6 +50,85 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 
 DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"
 
 
+def _apply_single_provider_filter(build_config: BuildConfig, single_provider_arg: str) -> BuildConfig:
+    """Filter a distribution to only include specified providers for certain APIs."""
+    # Parse the single-provider argument using the same logic as --providers
+    provider_filters: dict[str, str] = {}
+    for api_provider in single_provider_arg.split(","):
+        if "=" not in api_provider:
+            cprint(
+                "Could not parse `--single-provider`. Please ensure the list is in the format api1=provider1,api2=provider2",
+                color="red",
+                file=sys.stderr,
+            )
+            sys.exit(1)
+        api, provider_type = api_provider.split("=")
+        provider_filters[api] = provider_type
+
+    # Create a copy of the build config to modify
+    filtered_build_config = BuildConfig(
+        image_type=build_config.image_type,
+        image_name=build_config.image_name,
+        external_providers_dir=build_config.external_providers_dir,
+        external_apis_dir=build_config.external_apis_dir,
+        distribution_spec=DistributionSpec(
+            providers={},
+            description=build_config.distribution_spec.description,
+        ),
+    )
+
+    # Copy all providers, but filter the specified APIs
+    for api, providers in build_config.distribution_spec.providers.items():
+        if api in provider_filters:
+            target_provider_type = provider_filters[api]
+            filtered_providers = [p for p in providers if p.provider_type == target_provider_type]
+            if not filtered_providers:
+                cprint(
+                    f"Provider {target_provider_type} not found in distribution for API {api}",
+                    color="red",
+                    file=sys.stderr,
+                )
+                sys.exit(1)
+            filtered_build_config.distribution_spec.providers[api] = filtered_providers
+        else:
+            # Keep all providers for unfiltered APIs
+            filtered_build_config.distribution_spec.providers[api] = providers
+
+    return filtered_build_config
+
+
+def _generate_filtered_run_config(
+    build_config: BuildConfig,
+    build_dir: Path,
+    distro_name: str,
+) -> Path:
+    """
+    Generate a filtered run.yaml by starting with the original distribution's run.yaml
+    and filtering the providers according to the build_config.
+    """
+    # Load the original distribution's run.yaml
+    distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
+
+    with importlib.resources.as_file(distro_resource) as path:
+        with open(path) as f:
+            original_config = yaml.safe_load(f)
+
+    # Apply provider filtering to the loaded config
+    for api, providers in build_config.distribution_spec.providers.items():
+        if api in original_config.get("providers", {}):
+            # Filter this API to only include the providers from build_config
+            provider_types = {p.provider_type for p in providers}
+            filtered_providers = [p for p in original_config["providers"][api] if p["provider_type"] in provider_types]
+            original_config["providers"][api] = filtered_providers
+
+    # Write the filtered config
+    run_config_file = build_dir / f"{distro_name}-filtered-run.yaml"
+    with open(run_config_file, "w") as f:
+        yaml.dump(original_config, f, sort_keys=False)
+
+    return run_config_file
+
+
 @lru_cache
 def available_distros_specs() -> dict[str, BuildConfig]:
     import yaml
@@ -93,6 +172,11 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             )
             sys.exit(1)
         build_config = available_distros[distro_name]
+
+        # Apply single-provider filtering if specified
+        if args.single_provider:
+            build_config = _apply_single_provider_filter(build_config, args.single_provider)
+
         if args.image_type:
             build_config.image_type = args.image_type
         else:
@@ -245,6 +329,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             image_name=image_name,
             config_path=args.config,
             distro_name=distro_name,
+            is_filtered=bool(args.single_provider),
         )
 
     except (Exception, RuntimeError) as exc:
@@ -363,6 +448,7 @@ def _run_stack_build_command_from_build_config(
     image_name: str | None = None,
     distro_name: str | None = None,
     config_path: str | None = None,
+    is_filtered: bool = False,
 ) -> Path | Traversable:
     image_name = image_name or build_config.image_name
     if build_config.image_type == LlamaStackImageType.CONTAINER.value:
@@ -435,12 +521,19 @@ def _run_stack_build_command_from_build_config(
         raise RuntimeError(f"Failed to build image {image_name}")
 
     if distro_name:
-        # copy run.yaml from distribution to build_dir instead of generating it again
-        distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
-        run_config_file = build_dir / f"{distro_name}-run.yaml"
+        # If single-provider filtering was applied, generate a filtered run config
+        # Otherwise, copy run.yaml from distribution as before
+        if is_filtered:
+            run_config_file = _generate_filtered_run_config(build_config, build_dir, distro_name)
+            distro_path = run_config_file  # Use the generated file as the distro_path
+        else:
+            # copy run.yaml from distribution to build_dir instead of generating it again
+            distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
+            run_config_file = build_dir / f"{distro_name}-run.yaml"
 
-        with importlib.resources.as_file(distro_path) as path:
-            shutil.copy(path, run_config_file)
+            with importlib.resources.as_file(distro_resource) as path:
+                shutil.copy(path, run_config_file)
+            distro_path = run_config_file  # Update distro_path to point to the copied file
 
         cprint("Build Successful!", color="green", file=sys.stderr)
         cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py
index 80cf6fb38..a6c466c6f 100644
--- a/llama_stack/cli/stack/build.py
+++ b/llama_stack/cli/stack/build.py
@@ -92,6 +92,13 @@ the build. If not specified, currently active environment will be used if found.
             help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
         )
 
+        self.parser.add_argument(
+            "--single-provider",
+            type=str,
+            default=None,
+            help="Limit a distribution to a single provider for specific APIs. Format: api1=provider1,api2=provider2. Use with --distro to filter an existing distribution.",
+        )
+
    def _run_stack_build_command(self, args: argparse.Namespace) -> None:
        # always keep implementation completely silo-ed away from CLI so CLI
        # can be fast to load and reduces dependencies
diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py
index 1c33010c4..73c047979 100644
--- a/llama_stack/core/resolver.py
+++ b/llama_stack/core/resolver.py
@@ -409,10 +409,6 @@ async def instantiate_provider(
     if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
         args.append(run_config.telemetry.enabled)
 
-    # vector_io providers need access to run_config.vector_stores
-    if provider_spec.api == Api.vector_io and "run_config" in inspect.signature(getattr(module, method)).parameters:
-        args.append(run_config)
-
     fn = getattr(module, method)
     impl = await fn(*args)
     impl.__provider_id__ = provider.provider_id
diff --git a/llama_stack/core/routers/__init__.py b/llama_stack/core/routers/__init__.py
index 4463d2460..6f9a4a9e0 100644
--- a/llama_stack/core/routers/__init__.py
+++ b/llama_stack/core/routers/__init__.py
@@ -84,6 +84,9 @@ async def get_auto_router_impl(
         await inference_store.initialize()
         api_to_dep_impl["store"] = inference_store
 
+    if api == Api.vector_io and run_config.vector_stores:
+        api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
+
     impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
     await impl.initialize()
     return impl
diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py
index f4e871a40..d559d16e0 100644
--- a/llama_stack/core/routers/vector_io.py
+++ b/llama_stack/core/routers/vector_io.py
@@ -31,6 +31,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreObject,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 
@@ -43,9 +44,11 @@ class VectorIORouter(VectorIO):
     def __init__(
         self,
         routing_table: RoutingTable,
+        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         logger.debug("Initializing VectorIORouter")
         self.routing_table = routing_table
+        self.vector_stores_config = vector_stores_config
 
     async def initialize(self) -> None:
         logger.debug("VectorIORouter.initialize")
@@ -122,6 +125,10 @@ class VectorIORouter(VectorIO):
         embedding_dimension = extra.get("embedding_dimension")
         provider_id = extra.get("provider_id")
 
+        if embedding_model is None and self.vector_stores_config is not None:
+            embedding_model = self.vector_stores_config.default_embedding_model_id
+            logger.debug(f"Using default embedding model: {embedding_model}")
+
         if embedding_model is not None and embedding_dimension is None:
             embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
 
diff --git a/llama_stack/distributions/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml
index 191d0ae59..c01e415a9 100644
--- a/llama_stack/distributions/ci-tests/build.yaml
+++ b/llama_stack/distributions/ci-tests/build.yaml
@@ -25,6 +25,8 @@ distribution_spec:
     - provider_type: inline::milvus
     - provider_type: remote::chromadb
     - provider_type: remote::pgvector
+    - provider_type: remote::qdrant
+    - provider_type: remote::weaviate
   files:
     - provider_type: inline::localfs
   safety:
diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml
index 42741f102..59decf197 100644
--- a/llama_stack/distributions/ci-tests/run.yaml
+++ b/llama_stack/distributions/ci-tests/run.yaml
@@ -128,6 +128,21 @@ providers:
       kvstore:
         type: sqlite
         db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
+  - provider_id: ${env.QDRANT_URL:+qdrant}
+    provider_type: remote::qdrant
+    config:
+      api_key: ${env.QDRANT_API_KEY:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/qdrant_registry.db
+  - provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
+    provider_type: remote::weaviate
+    config:
+      weaviate_api_key: null
+      weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/weaviate_registry.db
   files:
   - provider_id: meta-reference-files
     provider_type: inline::localfs
diff --git a/llama_stack/distributions/starter-gpu/build.yaml b/llama_stack/distributions/starter-gpu/build.yaml
index 943c6134d..b2e2a0c85 100644
--- a/llama_stack/distributions/starter-gpu/build.yaml
+++ b/llama_stack/distributions/starter-gpu/build.yaml
@@ -26,6 +26,8 @@ distribution_spec:
     - provider_type: inline::milvus
     - provider_type: remote::chromadb
     - provider_type: remote::pgvector
+    - provider_type: remote::qdrant
+    - provider_type: remote::weaviate
   files:
     - provider_type: inline::localfs
   safety:
diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml
index 9593b2947..d98b26b8a 100644
--- a/llama_stack/distributions/starter-gpu/run.yaml
+++ b/llama_stack/distributions/starter-gpu/run.yaml
@@ -128,6 +128,21 @@ providers:
       kvstore:
         type: sqlite
         db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
+  - provider_id: ${env.QDRANT_URL:+qdrant}
+    provider_type: remote::qdrant
+    config:
+      api_key: ${env.QDRANT_API_KEY:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/qdrant_registry.db
+  - provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
+    provider_type: remote::weaviate
+    config:
+      weaviate_api_key: null
+      weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/weaviate_registry.db
   files:
   - provider_id: meta-reference-files
     provider_type: inline::localfs
@@ -240,7 +255,7 @@ tool_groups:
   provider_id: rag-runtime
 server:
   port: 8321
-vector_stores:
-  default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
 telemetry:
   enabled: true
+vector_stores:
+  default_embedding_model_id: sentence-transformers/nomic-ai/nomic-embed-text-v1.5
diff --git a/llama_stack/distributions/starter/build.yaml b/llama_stack/distributions/starter/build.yaml
index c2719d50d..baa80ef3e 100644
--- a/llama_stack/distributions/starter/build.yaml
+++ b/llama_stack/distributions/starter/build.yaml
@@ -26,6 +26,8 @@ distribution_spec:
     - provider_type: inline::milvus
     - provider_type: remote::chromadb
     - provider_type: remote::pgvector
+    - provider_type: remote::qdrant
+    - provider_type: remote::weaviate
   files:
     - provider_type: inline::localfs
   safety:
diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml
index 2e2e5c1b0..f47a8e1b2 100644
--- a/llama_stack/distributions/starter/run.yaml
+++ b/llama_stack/distributions/starter/run.yaml
@@ -128,6 +128,21 @@ providers:
       kvstore:
         type: sqlite
         db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
+  - provider_id: ${env.QDRANT_URL:+qdrant}
+    provider_type: remote::qdrant
+    config:
+      api_key: ${env.QDRANT_API_KEY:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/qdrant_registry.db
+  - provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
+    provider_type: remote::weaviate
+    config:
+      weaviate_api_key: null
+      weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/weaviate_registry.db
   files:
   - provider_id: meta-reference-files
     provider_type: inline::localfs
diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py
index 8d637a69a..da775f566 100644
--- a/llama_stack/distributions/starter/starter.py
+++ b/llama_stack/distributions/starter/starter.py
@@ -32,6 +32,8 @@ from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOC
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
 )
+from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
+from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
 from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
 
 
@@ -114,6 +116,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
             BuildProvider(provider_type="inline::milvus"),
             BuildProvider(provider_type="remote::chromadb"),
             BuildProvider(provider_type="remote::pgvector"),
+            BuildProvider(provider_type="remote::qdrant"),
+            BuildProvider(provider_type="remote::weaviate"),
         ],
         "files": [BuildProvider(provider_type="inline::localfs")],
         "safety": [
@@ -222,6 +226,22 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
                     password="${env.PGVECTOR_PASSWORD:=}",
                 ),
             ),
+            Provider(
+                provider_id="${env.QDRANT_URL:+qdrant}",
+                provider_type="remote::qdrant",
+                config=QdrantVectorIOConfig.sample_run_config(
+                    f"~/.llama/distributions/{name}",
+                    url="${env.QDRANT_URL:=}",
+                ),
+            ),
+            Provider(
+                provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
+                provider_type="remote::weaviate",
+                config=WeaviateVectorIOConfig.sample_run_config(
+                    f"~/.llama/distributions/{name}",
+                    cluster_url="${env.WEAVIATE_CLUSTER_URL:=}",
+                ),
+            ),
         ],
         "files": [files_provider],
     },
diff --git a/llama_stack/providers/inline/vector_io/chroma/__init__.py b/llama_stack/providers/inline/vector_io/chroma/__init__.py
index b1b4312eb..575e5ad88 100644
--- a/llama_stack/providers/inline/vector_io/chroma/__init__.py
+++ b/llama_stack/providers/inline/vector_io/chroma/__init__.py
@@ -6,29 +6,14 @@
 
 from typing import Any
 
-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api
 
 from .config import ChromaVectorIOConfig
 
 
-async def get_provider_impl(
-    config: ChromaVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
-):
-    from llama_stack.providers.remote.vector_io.chroma.chroma import (
-        ChromaVectorIOAdapter,
-    )
+async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
+    from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaVectorIOAdapter
 
-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = ChromaVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/vector_io/chroma/config.py b/llama_stack/providers/inline/vector_io/chroma/config.py
index a9566f7ff..e36384016 100644
--- a/llama_stack/providers/inline/vector_io/chroma/config.py
+++ b/llama_stack/providers/inline/vector_io/chroma/config.py
@@ -24,7 +24,6 @@ class ChromaVectorIOConfig(BaseModel):
         return {
             "db_path": db_path,
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="chroma_inline_registry.db",
+                __distro_dir__=__distro_dir__, db_name="chroma_inline_registry.db"
             ),
         }
diff --git a/llama_stack/providers/inline/vector_io/faiss/__init__.py b/llama_stack/providers/inline/vector_io/faiss/__init__.py
index e6aa2a1ef..24d1f292a 100644
--- a/llama_stack/providers/inline/vector_io/faiss/__init__.py
+++ b/llama_stack/providers/inline/vector_io/faiss/__init__.py
@@ -6,29 +6,16 @@
 
 from typing import Any
 
-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api
 
 from .config import FaissVectorIOConfig
 
 
-async def get_provider_impl(
-    config: FaissVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
-):
+async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
     from .faiss import FaissVectorIOAdapter
 
     assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
 
-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = FaissVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/vector_io/faiss/config.py b/llama_stack/providers/inline/vector_io/faiss/config.py
index cbcbb1762..d21049d8b 100644
--- a/llama_stack/providers/inline/vector_io/faiss/config.py
+++ b/llama_stack/providers/inline/vector_io/faiss/config.py
@@ -8,10 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel
 
-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -22,8 +19,5 @@ class FaissVectorIOConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
-            "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="faiss_store.db",
-            )
+            "kvstore": SqliteKVStoreConfig.sample_run_config(__distro_dir__=__distro_dir__, db_name="faiss_store.db")
         }
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index 7ec73bed5..0a54f9a5e 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -17,28 +17,14 @@ from numpy.typing import NDArray
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import (
-    Chunk,
-    QueryChunksResponse,
-    VectorIO,
-)
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import (
-    HealthResponse,
-    HealthStatus,
-    VectorDBsProtocolPrivate,
-)
+from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
-from llama_stack.providers.utils.memory.vector_store import (
-    ChunkForDeletion,
-    EmbeddingIndex,
-    VectorDBWithIndex,
-)
+from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
 
 from .config import FaissVectorIOConfig
 
@@ -156,12 +142,7 @@ class FaissIndex(EmbeddingIndex):
 
         await self._save_index()
 
-    async def query_vector(
-        self,
-        embedding: NDArray,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
         chunks = []
         scores = []
@@ -176,12 +157,7 @@ class FaissIndex(EmbeddingIndex):
 
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         raise NotImplementedError(
             "Keyword search is not supported - underlying DB FAISS does not support this search mode"
         )
@@ -201,19 +177,10 @@ class FaissIndex(EmbeddingIndex):
 
 
 class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
-    def __init__(
-        self,
-        config: FaissVectorIOConfig,
-        inference_api: Inference,
-        models_api: Models,
-        files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
-    ) -> None:
+    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorDBWithIndex] = {}
 
     async def initialize(self) -> None:
@@ -255,17 +222,11 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
 
-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         assert self.kvstore is not None
         key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
-        await self.kvstore.set(
-            key=key,
-            value=vector_db.model_dump_json(),
-        )
+        await self.kvstore.set(key=key, value=vector_db.model_dump_json())
 
         # Store in cache
         self.cache[vector_db.identifier] = VectorDBWithIndex(
@@ -288,12 +249,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
             del self.cache[vector_db_id]
         await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
 
-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = self.cache.get(vector_db_id)
         if index is None:
             raise ValueError(f"Vector DB {vector_db_id} not found. found: {self.cache.keys()}")
@@ -301,10 +257,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
         await index.insert_chunks(chunks)
 
     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = self.cache.get(vector_db_id)
         if index is None:
diff --git a/llama_stack/providers/inline/vector_io/milvus/__init__.py b/llama_stack/providers/inline/vector_io/milvus/__init__.py
index 239abd35d..7dc9c6a33 100644
--- a/llama_stack/providers/inline/vector_io/milvus/__init__.py
+++ b/llama_stack/providers/inline/vector_io/milvus/__init__.py
@@ -6,27 +6,14 @@
 
 from typing import Any
 
-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api
 
 from .config import MilvusVectorIOConfig
 
 
-async def get_provider_impl(
-    config: MilvusVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
-):
+async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
     from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
 
-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = MilvusVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps.get(Api.models),
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/vector_io/milvus/config.py b/llama_stack/providers/inline/vector_io/milvus/config.py
index 8cbd056be..fe88c1993 100644
--- a/llama_stack/providers/inline/vector_io/milvus/config.py
+++ b/llama_stack/providers/inline/vector_io/milvus/config.py
@@ -8,10 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -26,7 +23,6 @@ class MilvusVectorIOConfig(BaseModel):
         return {
             "db_path": "${env.MILVUS_DB_PATH:=" + __distro_dir__ + "}/" + "milvus.db",
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="milvus_registry.db",
+                __distro_dir__=__distro_dir__, db_name="milvus_registry.db"
             ),
         }
diff --git a/llama_stack/providers/inline/vector_io/qdrant/__init__.py b/llama_stack/providers/inline/vector_io/qdrant/__init__.py
index efa4e3dcc..bef6d50e6 100644
--- a/llama_stack/providers/inline/vector_io/qdrant/__init__.py
+++ b/llama_stack/providers/inline/vector_io/qdrant/__init__.py
@@ -6,28 +6,15 @@
 
 from typing import Any
 
-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api
 
 from .config import QdrantVectorIOConfig
 
 
-async def get_provider_impl(
-    config: QdrantVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
-):
+async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
     from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
 
-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
     assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = QdrantVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/vector_io/qdrant/config.py b/llama_stack/providers/inline/vector_io/qdrant/config.py
index e15c27ea1..2e4083c11 100644
--- a/llama_stack/providers/inline/vector_io/qdrant/config.py
+++ b/llama_stack/providers/inline/vector_io/qdrant/config.py
@@ -9,10 +9,7 @@ from typing import Any
 
 from pydantic import BaseModel
 
-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type
 
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py
index 4748a3874..df96e927c 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py
@@ -6,28 +6,15 @@
 
 from typing import Any
 
-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api
 
 from .config import SQLiteVectorIOConfig
 
 
-async def get_provider_impl(
-    config: SQLiteVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
-):
+async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
     from .sqlite_vec import SQLiteVecVectorIOAdapter
 
-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = SQLiteVecVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py
index 525ed4b1f..90ad86215 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py
@@ -8,10 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 
 
 class SQLiteVectorIOConfig(BaseModel):
@@ -23,7 +20,6 @@ class SQLiteVectorIOConfig(BaseModel):
         return {
             "db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="sqlite_vec_registry.db",
+                __distro_dir__=__distro_dir__, db_name="sqlite_vec_registry.db"
             ),
         }
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index 94d06e830..9fcdf39b0 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -17,14 +17,8 @@ from numpy.typing import NDArray
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import (
-    Chunk,
-    QueryChunksResponse,
-    VectorIO,
-)
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -176,32 +170,18 @@ class SQLiteVecIndex(EmbeddingIndex):
 
                     # Insert vector embeddings
                     embedding_data = [
-                        (
-                            (
-                                chunk.chunk_id,
-                                serialize_vector(emb.tolist()),
-                            )
-                        )
+                        ((chunk.chunk_id, serialize_vector(emb.tolist())))
                         for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
                     ]
-                    cur.executemany(
-                        f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
-                        embedding_data,
-                    )
+                    cur.executemany(f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);", embedding_data)
 
                     # Insert FTS content
                     fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
                     # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
-                    cur.executemany(
-                        f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
-                        [(row[0],) for row in fts_data],
-                    )
+                    cur.executemany(f"DELETE FROM [{self.fts_table}] WHERE id = ?;", [(row[0],) for row in fts_data])
 
                     # INSERT new entries
-                    cur.executemany(
-                        f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
-                        fts_data,
-                    )
+                    cur.executemany(f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);", fts_data)
 
                 connection.commit()
 
@@ -217,12 +197,7 @@ class SQLiteVecIndex(EmbeddingIndex):
         # Run batch insertion in a background thread
         await asyncio.to_thread(_execute_all_batch_inserts)
 
-    async def query_vector(
-        self,
-        embedding: NDArray,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         """
         Performs vector-based search using a virtual table for vector similarity.
         """
@@ -262,12 +237,7 @@ class SQLiteVecIndex(EmbeddingIndex):
             scores.append(score)
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         """
         Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
         """
@@ -411,19 +381,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
     and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
     """
 
-    def __init__(
-        self,
-        config,
-        inference_api: Inference,
-        models_api: Models,
-        files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
-    ) -> None:
+    def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.vector_db_store = None
 
@@ -436,9 +397,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
         for db_json in stored_vector_dbs:
             vector_db = VectorDB.model_validate_json(db_json)
             index = await SQLiteVecIndex.create(
-                vector_db.embedding_dimension,
-                self.config.db_path,
-                vector_db.identifier,
+                vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
             )
             self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
 
@@ -453,11 +412,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
         return [v.vector_db for v in self.cache.values()]
 
     async def register_vector_db(self, vector_db: VectorDB) -> None:
-        index = await SQLiteVecIndex.create(
-            vector_db.embedding_dimension,
-            self.config.db_path,
-            vector_db.identifier,
-        )
+        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
         self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
 
     async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
diff --git a/llama_stack/providers/remote/vector_io/chroma/__init__.py b/llama_stack/providers/remote/vector_io/chroma/__init__.py
index e0c9df638..e4b77c68d 100644
--- a/llama_stack/providers/remote/vector_io/chroma/__init__.py
+++ b/llama_stack/providers/remote/vector_io/chroma/__init__.py
@@ -4,27 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec
 
 from .config import ChromaVectorIOConfig
 
 
-async def get_adapter_impl(
-    config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .chroma import ChromaVectorIOAdapter
 
-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = ChromaVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 9ed9672cd..0310aa3d3 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -13,25 +13,15 @@ from numpy.typing import NDArray
 
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import (
-    Chunk,
-    QueryChunksResponse,
-    VectorIO,
-)
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
-from llama_stack.providers.utils.memory.vector_store import (
-    ChunkForDeletion,
-    EmbeddingIndex,
-    VectorDBWithIndex,
-)
+from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
 
 from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
 
@@ -70,19 +60,13 @@ class ChromaIndex(EmbeddingIndex):
         ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
 
         await maybe_await(
-            self.collection.add(
-                documents=[chunk.model_dump_json() for chunk in chunks],
-                embeddings=embeddings,
-                ids=ids,
-            )
+            self.collection.add(documents=[chunk.model_dump_json() for chunk in chunks], embeddings=embeddings, ids=ids)
         )
 
     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         results = await maybe_await(
             self.collection.query(
-                query_embeddings=[embedding.tolist()],
-                n_results=k,
-                include=["documents", "distances"],
+                query_embeddings=[embedding.tolist()], n_results=k, include=["documents", "distances"]
             )
         )
         distances = results["distances"][0]
@@ -110,12 +94,7 @@ class ChromaIndex(EmbeddingIndex):
     async def delete(self):
         await maybe_await(self.client.delete_collection(self.collection.name))
 
-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         raise NotImplementedError("Keyword search is not supported in Chroma")
 
     async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
@@ -140,16 +119,12 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self,
         config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
         inference_api: Inference,
-        models_apis: Models,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
         self.config = config
         self.inference_api = inference_api
-        self.models_api = models_apis
-        self.vector_stores_config = vector_stores_config
         self.client = None
         self.cache = {}
         self.vector_db_store = None
@@ -176,14 +151,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         # Clean up mixin resources (file batch tasks)
         await super().shutdown()
 
-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         collection = await maybe_await(
             self.client.get_or_create_collection(
-                name=vector_db.identifier,
-                metadata={"vector_db": vector_db.model_dump_json()},
+                name=vector_db.identifier, metadata={"vector_db": vector_db.model_dump_json()}
             )
         )
         self.cache[vector_db.identifier] = VectorDBWithIndex(
@@ -198,12 +169,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         await self.cache[vector_db_id].index.delete()
         del self.cache[vector_db_id]
 
-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if index is None:
             raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
@@ -211,10 +177,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         await index.insert_chunks(chunks)
 
     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
diff --git a/llama_stack/providers/remote/vector_io/chroma/config.py b/llama_stack/providers/remote/vector_io/chroma/config.py
index a1193905a..c870a75b9 100644
--- a/llama_stack/providers/remote/vector_io/chroma/config.py
+++ b/llama_stack/providers/remote/vector_io/chroma/config.py
@@ -22,7 +22,6 @@ class ChromaVectorIOConfig(BaseModel):
         return {
             "url": url,
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="chroma_remote_registry.db",
+                __distro_dir__=__distro_dir__, db_name="chroma_remote_registry.db"
             ),
         }
diff --git a/llama_stack/providers/remote/vector_io/milvus/__init__.py b/llama_stack/providers/remote/vector_io/milvus/__init__.py
index 0b9721beb..526075bb2 100644
--- a/llama_stack/providers/remote/vector_io/milvus/__init__.py
+++ b/llama_stack/providers/remote/vector_io/milvus/__init__.py
@@ -4,28 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec
 
 from .config import MilvusVectorIOConfig
 
 
-async def get_adapter_impl(
-    config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .milvus import MilvusVectorIOAdapter
 
-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
     assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = MilvusVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/vector_io/milvus/config.py b/llama_stack/providers/remote/vector_io/milvus/config.py
index 899d3678d..3134c8966 100644
--- a/llama_stack/providers/remote/vector_io/milvus/config.py
+++ b/llama_stack/providers/remote/vector_io/milvus/config.py
@@ -29,7 +29,6 @@ class MilvusVectorIOConfig(BaseModel):
             "uri": "${env.MILVUS_ENDPOINT}",
             "token": "${env.MILVUS_TOKEN}",
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="milvus_remote_registry.db",
+                __distro_dir__=__distro_dir__, db_name="milvus_remote_registry.db"
             ),
         }
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index dd01f6637..3278fea0c 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -14,14 +14,8 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import (
-    Chunk,
-    QueryChunksResponse,
-    VectorIO,
-)
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
@@ -75,46 +69,23 @@ class MilvusIndex(EmbeddingIndex):
             logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
             # Create schema for vector search
             schema = self.client.create_schema()
-            schema.add_field(
-                field_name="chunk_id",
-                datatype=DataType.VARCHAR,
-                is_primary=True,
-                max_length=100,
-            )
+            schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
             schema.add_field(
                 field_name="content",
                 datatype=DataType.VARCHAR,
                 max_length=65535,
                 enable_analyzer=True,  # Enable text analysis for BM25
             )
-            schema.add_field(
-                field_name="vector",
-                datatype=DataType.FLOAT_VECTOR,
-                dim=len(embeddings[0]),
-            )
-            schema.add_field(
-                field_name="chunk_content",
-                datatype=DataType.JSON,
-            )
+            schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=len(embeddings[0]))
+            schema.add_field(field_name="chunk_content", datatype=DataType.JSON)
             # Add sparse vector field for BM25 (required by the function)
-            schema.add_field(
-                field_name="sparse",
-                datatype=DataType.SPARSE_FLOAT_VECTOR,
-            )
+            schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR)
 
             # Create indexes
             index_params = self.client.prepare_index_params()
-            index_params.add_index(
-                field_name="vector",
-                index_type="FLAT",
-                metric_type="COSINE",
-            )
+            index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")
             # Add index for sparse field (required by BM25 function)
-            index_params.add_index(
-                field_name="sparse",
-                index_type="SPARSE_INVERTED_INDEX",
-                metric_type="BM25",
-            )
+            index_params.add_index(field_name="sparse", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25")
 
             # Add BM25 function for full-text search
             bm25_function = Function(
@@ -145,11 +116,7 @@ class MilvusIndex(EmbeddingIndex):
                 }
             )
         try:
-            await asyncio.to_thread(
-                self.client.insert,
-                self.collection_name,
-                data=data,
-            )
+            await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
         except Exception as e:
             logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
             raise e
@@ -168,12 +135,7 @@ class MilvusIndex(EmbeddingIndex):
         scores = [res["distance"] for res in search_res[0]]
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         """
         Perform BM25-based keyword search using Milvus's built-in full-text search.
         """
@@ -211,12 +173,7 @@ class MilvusIndex(EmbeddingIndex):
             # Fallback to simple text search
             return await self._fallback_keyword_search(query_string, k, score_threshold)
 
-    async def _fallback_keyword_search(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def _fallback_keyword_search(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         """
         Fallback to simple text search when BM25 search is not available.
         """
@@ -309,17 +266,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self,
         config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
         inference_api: Inference,
-        models_api: Models | None,
         files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.cache = {}
         self.client = None
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.vector_db_store = None
         self.metadata_collection_name = "openai_vector_stores_metadata"
 
@@ -358,10 +311,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         # Clean up mixin resources (file batch tasks)
         await super().shutdown()
 
-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         if isinstance(self.config, RemoteMilvusVectorIOConfig):
             consistency_level = self.config.consistency_level
         else:
@@ -398,12 +348,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
             await self.cache[vector_db_id].index.delete()
             del self.cache[vector_db_id]
 
-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)
@@ -411,10 +356,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         await index.insert_chunks(chunks)
 
     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
diff --git a/llama_stack/providers/remote/vector_io/pgvector/__init__.py b/llama_stack/providers/remote/vector_io/pgvector/__init__.py
index 309c8e159..8086b7650 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/__init__.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/__init__.py
@@ -4,26 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec
 
 from .config import PGVectorVectorIOConfig
 
 
-async def get_adapter_impl(
-    config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .pgvector import PGVectorVectorIOAdapter
 
-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-    impl = PGVectorVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/vector_io/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py
index 334cbe5be..e8a637958 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/config.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/config.py
@@ -8,10 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -42,7 +39,6 @@ class PGVectorVectorIOConfig(BaseModel):
             "user": user,
             "password": password,
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="pgvector_registry.db",
+                __distro_dir__=__distro_dir__, db_name="pgvector_registry.db"
             ),
         }
diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
index e58317bf0..1e167b60e 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -16,27 +16,15 @@ from pydantic import BaseModel, TypeAdapter
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import (
-    Chunk,
-    QueryChunksResponse,
-    VectorIO,
-)
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
-from llama_stack.providers.utils.memory.vector_store import (
-    ChunkForDeletion,
-    EmbeddingIndex,
-    VectorDBWithIndex,
-)
+from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
 from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
 
 from .config import PGVectorVectorIOConfig
@@ -206,12 +194,7 @@ class PGVectorIndex(EmbeddingIndex):
diff --git a/llama_stack/providers/remote/vector_io/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py
index 334cbe5be..e8a637958 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/config.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/config.py
@@ -8,10 +8,7 @@
 from typing import Any

 from pydantic import BaseModel, Field

-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig

 from llama_stack.schema_utils import json_schema_type

@@ -42,7 +39,6 @@ class PGVectorVectorIOConfig(BaseModel):
             "user": user,
             "password": password,
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="pgvector_registry.db",
+                __distro_dir__=__distro_dir__, db_name="pgvector_registry.db"
             ),
         }
diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
index e58317bf0..1e167b60e 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -16,27 +16,15 @@ from pydantic import BaseModel, TypeAdapter
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import (
-    Chunk,
-    QueryChunksResponse,
-    VectorIO,
-)
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
-from llama_stack.providers.utils.memory.vector_store import (
-    ChunkForDeletion,
-    EmbeddingIndex,
-    VectorDBWithIndex,
-)
+from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex
 from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name

 from .config import PGVectorVectorIOConfig

@@ -206,12 +194,7 @@ class PGVectorIndex(EmbeddingIndex):

         return QueryChunksResponse(chunks=chunks, scores=scores)

-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         """
         Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
@@ -318,7 +301,7 @@ class PGVectorIndex(EmbeddingIndex):
         """Remove a chunk from the PostgreSQL table."""
         chunk_ids = [c.chunk_id for c in chunks_for_deletion]
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
-            cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
+            cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))

     def get_pgvector_search_function(self) -> str:
         return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
@@ -342,18 +325,11 @@ class PGVectorIndex(EmbeddingIndex):

 class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(
-        self,
-        config: PGVectorVectorIOConfig,
-        inference_api: Inference,
-        models_api: Models,
-        files_api: Files | None = None,
-        vector_stores_config: VectorStoresConfig | None = None,
+        self, config: PGVectorVectorIOConfig, inference_api: Inference, files_api: Files | None = None
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.conn = None
         self.cache = {}
         self.vector_db_store = None
@@ -410,11 +386,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
             vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
         )
         await pgvector_index.initialize()
-        index = VectorDBWithIndex(
-            vector_db,
-            index=pgvector_index,
-            inference_api=self.inference_api,
-        )
+        index = VectorDBWithIndex(vector_db, index=pgvector_index, inference_api=self.inference_api)
         self.cache[vector_db.identifier] = index

     async def unregister_vector_db(self, vector_db_id: str) -> None:
@@ -427,20 +399,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
         assert self.kvstore is not None
         await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")

-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         await index.insert_chunks(chunks)

     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         return await index.query_chunks(query, params)
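
The trailing comma in `(chunk_ids,)` is load-bearing: psycopg2 binds the single `%s` in `ANY(%s)` to the whole list only when the parameters are passed as a one-element sequence. A standalone sketch (the DSN and table name are placeholders, not from this codebase):

    # Sketch only: shows why the parameter tuple matters for ANY(%s).
    import psycopg2

    conn = psycopg2.connect("dbname=example")  # placeholder DSN
    ids = ["chunk-1", "chunk-2"]
    with conn.cursor() as cur:
        # (ids,) binds the whole list to the single %s placeholder; a bare
        # (ids) is just the list itself, so psycopg2 would spread its elements
        # across placeholders and raise on the count mismatch.
        cur.execute("DELETE FROM chunks_table WHERE id = ANY(%s)", (ids,))
    conn.commit()
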
diff --git a/llama_stack/providers/remote/vector_io/qdrant/__init__.py b/llama_stack/providers/remote/vector_io/qdrant/__init__.py
index abdef9775..e9527f101 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/__init__.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/__init__.py
@@ -4,27 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec

 from .config import QdrantVectorIOConfig


-async def get_adapter_impl(
-    config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorIOAdapter

-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = QdrantVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = QdrantVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/vector_io/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py
index ff5506236..5aff2241b 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/config.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/config.py
@@ -8,10 +8,7 @@
 from typing import Any

 from pydantic import BaseModel

-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig

 from llama_stack.schema_utils import json_schema_type

@@ -34,7 +31,6 @@ class QdrantVectorIOConfig(BaseModel):
         return {
             "api_key": "${env.QDRANT_API_KEY:=}",
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="qdrant_registry.db",
+                __distro_dir__=__distro_dir__, db_name="qdrant_registry.db"
             ),
         }
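
`sample_run_config` is what distribution templates call to embed a provider block into a generated run.yaml. A quick way to inspect the generated section; the exact kvstore shape depends on `SqliteKVStoreConfig` internals, so treat the printed output as illustrative:

    # Sketch only: prints the sample qdrant provider section a template embeds.
    from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig

    print(QdrantVectorIOConfig.sample_run_config(__distro_dir__="~/.llama/distributions/starter"))
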
diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
index db1e21a6c..b0f8123a8 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -16,7 +16,6 @@ from qdrant_client.models import PointStruct
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
@@ -25,17 +24,12 @@ from llama_stack.apis.vector_io import (
     VectorStoreChunkingStrategy,
     VectorStoreFileObject,
 )
-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
-from llama_stack.providers.utils.memory.vector_store import (
-    ChunkForDeletion,
-    EmbeddingIndex,
-    VectorDBWithIndex,
-)
+from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex

 from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig

@@ -100,8 +94,7 @@ class QdrantIndex(EmbeddingIndex):
         chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
         try:
             await self.client.delete(
-                collection_name=self.collection_name,
-                points_selector=models.PointIdsList(points=chunk_ids),
+                collection_name=self.collection_name, points_selector=models.PointIdsList(points=chunk_ids)
             )
         except Exception as e:
             log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
@@ -134,12 +127,7 @@ class QdrantIndex(EmbeddingIndex):

         return QueryChunksResponse(chunks=chunks, scores=scores)

-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         raise NotImplementedError("Keyword search is not supported in Qdrant")

     async def query_hybrid(
@@ -162,17 +150,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self,
         config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
         inference_api: Inference,
-        models_api: Models,
         files_api: Files | None = None,
-        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.client: AsyncQdrantClient = None
         self.cache = {}
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.vector_db_store = None
         self._qdrant_lock = asyncio.Lock()
@@ -187,11 +171,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         for vector_db_data in stored_vector_dbs:
             vector_db = VectorDB.model_validate_json(vector_db_data)
-            index = VectorDBWithIndex(
-                vector_db,
-                QdrantIndex(self.client, vector_db.identifier),
-                self.inference_api,
-            )
+            index = VectorDBWithIndex(vector_db, QdrantIndex(self.client, vector_db.identifier), self.inference_api)
             self.cache[vector_db.identifier] = index

         self.openai_vector_stores = await self._load_openai_vector_stores()
@@ -200,18 +180,13 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         # Clean up mixin resources (file batch tasks)
         await super().shutdown()

-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         assert self.kvstore is not None
         key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
         await self.kvstore.set(key=key, value=vector_db.model_dump_json())

         index = VectorDBWithIndex(
-            vector_db=vector_db,
-            index=QdrantIndex(self.client, vector_db.identifier),
-            inference_api=self.inference_api,
+            vector_db=vector_db, index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api
         )

         self.cache[vector_db.identifier] = index
@@ -243,12 +218,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
             self.cache[vector_db_id] = index
         return index

-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)
@@ -256,10 +226,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         await index.insert_chunks(chunks)

     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
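
The collapsed `client.delete(...)` call above keeps the same qdrant-client shape: one request with a `PointIdsList` selector. A self-contained sketch against an in-memory instance (the collection name and point ids are made up for the example):

    # Sketch only: mirrors the adapter's single-call delete.
    import asyncio

    from qdrant_client import AsyncQdrantClient, models

    async def main() -> None:
        client = AsyncQdrantClient(location=":memory:")
        await client.create_collection(
            collection_name="demo",
            vectors_config=models.VectorParams(size=4, distance=models.Distance.COSINE),
        )
        await client.delete(
            collection_name="demo",
            points_selector=models.PointIdsList(points=[1, 2, 3]),
        )

    asyncio.run(main())
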
diff --git a/llama_stack/providers/remote/vector_io/weaviate/__init__.py b/llama_stack/providers/remote/vector_io/weaviate/__init__.py
index ab0277cc7..12e11d013 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/__init__.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/__init__.py
@@ -4,27 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api, ProviderSpec

 from .config import WeaviateVectorIOConfig


-async def get_adapter_impl(
-    config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec], run_config: StackRunConfig | None = None
-):
+async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .weaviate import WeaviateVectorIOAdapter

-    vector_stores_config = None
-    if run_config and run_config.vector_stores:
-        vector_stores_config = run_config.vector_stores
-
-    impl = WeaviateVectorIOAdapter(
-        config,
-        deps[Api.inference],
-        deps[Api.models],
-        deps.get(Api.files),
-        vector_stores_config,
-    )
+    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/vector_io/weaviate/config.py b/llama_stack/providers/remote/vector_io/weaviate/config.py
index b693e294e..b1b07d208 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/config.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/config.py
@@ -8,10 +8,7 @@
 from typing import Any

 from pydantic import BaseModel, Field

-from llama_stack.providers.utils.kvstore.config import (
-    KVStoreConfig,
-    SqliteKVStoreConfig,
-)
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig

 from llama_stack.schema_utils import json_schema_type

@@ -22,16 +19,11 @@ class WeaviateVectorIOConfig(BaseModel):
     kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)

     @classmethod
-    def sample_run_config(
-        cls,
-        __distro_dir__: str,
-        **kwargs: Any,
-    ) -> dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "weaviate_api_key": None,
             "weaviate_cluster_url": "${env.WEAVIATE_CLUSTER_URL:=localhost:8080}",
             "kvstore": SqliteKVStoreConfig.sample_run_config(
-                __distro_dir__=__distro_dir__,
-                db_name="weaviate_registry.db",
+                __distro_dir__=__distro_dir__, db_name="weaviate_registry.db"
             ),
         }
diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index f8046e127..d7bb07d4f 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -16,18 +16,14 @@ from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
-from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
-    OpenAIVectorStoreMixin,
-)
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     RERANKER_TYPE_RRF,
     ChunkForDeletion,
@@ -49,12 +45,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten


 class WeaviateIndex(EmbeddingIndex):
-    def __init__(
-        self,
-        client: weaviate.WeaviateClient,
-        collection_name: str,
-        kvstore: KVStore | None = None,
-    ):
+    def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None):
         self.client = client
         self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
         self.kvstore = kvstore
@@ -109,9 +100,7 @@ class WeaviateIndex(EmbeddingIndex):

         try:
             results = collection.query.near_vector(
-                near_vector=embedding.tolist(),
-                limit=k,
-                return_metadata=wvc.query.MetadataQuery(distance=True),
+                near_vector=embedding.tolist(), limit=k, return_metadata=wvc.query.MetadataQuery(distance=True)
             )
         except Exception as e:
             log.error(f"Weaviate client vector search failed: {e}")
@@ -154,12 +143,7 @@ class WeaviateIndex(EmbeddingIndex):
         collection = self.client.collections.get(sanitized_collection_name)
         collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))

-    async def query_keyword(
-        self,
-        query_string: str,
-        k: int,
-        score_threshold: float,
-    ) -> QueryChunksResponse:
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
         """
         Performs BM25-based keyword search using Weaviate's built-in full-text search.
         Args:
@@ -176,9 +160,7 @@ class WeaviateIndex(EmbeddingIndex):
         # Perform BM25 keyword search on chunk_content field
         try:
             results = collection.query.bm25(
-                query=query_string,
-                limit=k,
-                return_metadata=wvc.query.MetadataQuery(score=True),
+                query=query_string, limit=k, return_metadata=wvc.query.MetadataQuery(score=True)
             )
         except Exception as e:
             log.error(f"Weaviate client keyword search failed: {e}")
@@ -275,25 +257,11 @@ class WeaviateIndex(EmbeddingIndex):

         return QueryChunksResponse(chunks=chunks, scores=scores)


-class WeaviateVectorIOAdapter(
-    OpenAIVectorStoreMixin,
-    VectorIO,
-    NeedsRequestProviderData,
-    VectorDBsProtocolPrivate,
-):
-    def __init__(
-        self,
-        config: WeaviateVectorIOConfig,
-        inference_api: Inference,
-        models_api: Models,
-        files_api: Files | None,
-        vector_stores_config: VectorStoresConfig | None = None,
-    ) -> None:
+class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorDBsProtocolPrivate):
+    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.models_api = models_api
-        self.vector_stores_config = vector_stores_config
         self.client_cache = {}
         self.cache = {}
         self.vector_db_store = None
@@ -304,10 +272,7 @@ class WeaviateVectorIOAdapter(
             log.info("Using Weaviate locally in container")
             host, port = self.config.weaviate_cluster_url.split(":")
             key = "local_test"
-            client = weaviate.connect_to_local(
-                host=host,
-                port=port,
-            )
+            client = weaviate.connect_to_local(host=host, port=port)
         else:
             log.info("Using Weaviate remote cluster with URL")
             key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
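
`connect_to_local` takes the host and port directly, matching the `host:port` split the adapter performs on `weaviate_cluster_url`. A minimal sketch, assuming a Weaviate instance is listening on localhost:8080:

    # Sketch only: weaviate-client v4 local connection.
    import weaviate

    host, port = "localhost:8080".split(":")
    client = weaviate.connect_to_local(host=host, port=int(port))
    try:
        print(client.is_ready())
    finally:
        client.close()
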
@@ -337,15 +302,9 @@ class WeaviateVectorIOAdapter(
         for raw in stored:
             vector_db = VectorDB.model_validate_json(raw)
             client = self._get_client()
-            idx = WeaviateIndex(
-                client=client,
-                collection_name=vector_db.identifier,
-                kvstore=self.kvstore,
-            )
+            idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
             self.cache[vector_db.identifier] = VectorDBWithIndex(
-                vector_db=vector_db,
-                index=idx,
-                inference_api=self.inference_api,
+                vector_db=vector_db, index=idx, inference_api=self.inference_api
             )

         # Load OpenAI vector stores metadata into cache
@@ -357,10 +316,7 @@ class WeaviateVectorIOAdapter(
         # Clean up mixin resources (file batch tasks)
         await super().shutdown()

-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         client = self._get_client()
         sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
         # Create collection if it doesn't exist
@@ -369,17 +325,12 @@ class WeaviateVectorIOAdapter(
                 name=sanitized_collection_name,
                 vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                 properties=[
-                    wvc.config.Property(
-                        name="chunk_content",
-                        data_type=wvc.config.DataType.TEXT,
-                    ),
+                    wvc.config.Property(name="chunk_content", data_type=wvc.config.DataType.TEXT),
                 ],
             )

         self.cache[vector_db.identifier] = VectorDBWithIndex(
-            vector_db,
-            WeaviateIndex(client=client, collection_name=sanitized_collection_name),
-            self.inference_api,
+            vector_db, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api
         )

     async def unregister_vector_db(self, vector_db_id: str) -> None:
@@ -415,12 +366,7 @@ class WeaviateVectorIOAdapter(
             self.cache[vector_db_id] = index
         return index

-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)
@@ -428,10 +374,7 @@ class WeaviateVectorIOAdapter(
         await index.insert_chunks(chunks)

     async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
+        self, vector_db_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
index fbbe42d28..7806d98c1 100644
--- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
+++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
@@ -17,7 +17,6 @@ from pydantic import TypeAdapter

 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files, OpenAIFileObject
-from llama_stack.apis.models import Model, Models
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
@@ -44,7 +43,6 @@ from llama_stack.apis.vector_io import (
     VectorStoreSearchResponse,
     VectorStoreSearchResponsePage,
 )
-from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore.api import KVStore
@@ -90,9 +88,6 @@ class OpenAIVectorStoreMixin(ABC):
         self.openai_file_batches: dict[str, dict[str, Any]] = {}
         self.files_api = files_api
         self.kvstore = kvstore
-        # These will be set by implementing classes
-        self.models_api: Models | None = None
-        self.vector_stores_config: VectorStoresConfig | None = None
         self._last_file_batch_cleanup_time = 0
         self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
@@ -398,21 +393,7 @@ class OpenAIVectorStoreMixin(ABC):
         vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")

         if embedding_model is None:
-            result = await self._get_default_embedding_model_and_dimension()
-            if result is None:
-                raise ValueError(
-                    "embedding_model is required in extra_body when creating a vector store. "
-                    "No default embedding model could be determined automatically."
-                )
-            embedding_model, embedding_dimension = result
-        elif embedding_dimension is None:
-            # Embedding model was provided but dimension wasn't, look it up
-            embedding_dimension = await self._get_embedding_dimension_for_model(embedding_model)
-            if embedding_dimension is None:
-                raise ValueError(
-                    f"Could not determine embedding dimension for model '{embedding_model}'. "
-                    "Please provide embedding_dimension in extra_body or ensure the model metadata contains embedding_dimension."
-                )
+            raise ValueError("embedding_model is required")

         if embedding_dimension is None:
             raise ValueError("Embedding dimension is required")
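
With the models-API fallback removed, both embedding settings must arrive with the request or store creation fails fast. A hedged client-side sketch; the endpoint, model id, and dimension are examples, and the call shape follows the OpenAI-compatible surface rather than anything specific to this patch:

    # Sketch only: explicit embedding settings at store creation time.
    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")
    store = client.vector_stores.create(
        name="docs",
        extra_body={
            "embedding_model": "sentence-transformers/nomic-ai/nomic-embed-text-v1.5",
            "embedding_dimension": 768,
        },
    )
    print(store.id)
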
- """ - if not self.vector_stores_config: - logger.info("No vector stores config provided") - return None - - model_id = self.vector_stores_config.default_embedding_model_id - embedding_dimension = await self._get_embedding_dimension_for_model(model_id) - if embedding_dimension is None: - raise ValueError(f"Embedding model '{model_id}' not found or has no embedding_dimension in metadata") - - logger.debug( - f"Using default embedding model from vector stores config: {model_id} with dimension {embedding_dimension}" - ) - return model_id, embedding_dimension - async def openai_list_vector_stores( self, limit: int | None = 20, diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index cd5f49488..fd034fdc2 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -241,7 +241,7 @@ def instantiate_llama_stack_client(session): # --stack-config bypasses template so need this to set default embedding model if "vector_io" in config and "inference" in config: run_config.vector_stores = VectorStoresConfig( - default_embedding_model_id="inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5" + default_embedding_model_id="sentence-transformers/nomic-ai/nomic-embed-text-v1.5" ) run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") diff --git a/tests/unit/distribution/test_single_provider_filter.py b/tests/unit/distribution/test_single_provider_filter.py new file mode 100644 index 000000000..7447e3eaa --- /dev/null +++ b/tests/unit/distribution/test_single_provider_filter.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
diff --git a/tests/unit/distribution/test_single_provider_filter.py b/tests/unit/distribution/test_single_provider_filter.py
new file mode 100644
index 000000000..7447e3eaa
--- /dev/null
+++ b/tests/unit/distribution/test_single_provider_filter.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from llama_stack.cli.stack._build import _apply_single_provider_filter
+from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
+from llama_stack.core.utils.image_types import LlamaStackImageType
+
+
+def test_filters_single_api():
+    """Test filtering keeps only specified provider for one API."""
+    build_config = BuildConfig(
+        image_type=LlamaStackImageType.VENV.value,
+        distribution_spec=DistributionSpec(
+            providers={
+                "vector_io": [
+                    BuildProvider(provider_type="inline::faiss"),
+                    BuildProvider(provider_type="inline::sqlite-vec"),
+                ],
+                "inference": [
+                    BuildProvider(provider_type="remote::openai"),
+                ],
+            },
+            description="Test",
+        ),
+    )
+
+    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::sqlite-vec")
+
+    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
+    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::sqlite-vec"
+    assert len(filtered.distribution_spec.providers["inference"]) == 1  # unchanged
+
+
+def test_filters_multiple_apis():
+    """Test filtering multiple APIs."""
+    build_config = BuildConfig(
+        image_type=LlamaStackImageType.VENV.value,
+        distribution_spec=DistributionSpec(
+            providers={
+                "vector_io": [
+                    BuildProvider(provider_type="inline::faiss"),
+                    BuildProvider(provider_type="inline::sqlite-vec"),
+                ],
+                "inference": [
+                    BuildProvider(provider_type="remote::openai"),
+                    BuildProvider(provider_type="remote::anthropic"),
+                ],
+            },
+            description="Test",
+        ),
+    )
+
+    filtered = _apply_single_provider_filter(build_config, "vector_io=inline::faiss,inference=remote::openai")
+
+    assert len(filtered.distribution_spec.providers["vector_io"]) == 1
+    assert filtered.distribution_spec.providers["vector_io"][0].provider_type == "inline::faiss"
+    assert len(filtered.distribution_spec.providers["inference"]) == 1
+    assert filtered.distribution_spec.providers["inference"][0].provider_type == "remote::openai"
+
+
+def test_provider_not_found_exits():
+    """Test error when specified provider doesn't exist."""
+    build_config = BuildConfig(
+        image_type=LlamaStackImageType.VENV.value,
+        distribution_spec=DistributionSpec(
+            providers={"vector_io": [BuildProvider(provider_type="inline::faiss")]},
+            description="Test",
+        ),
+    )
+
+    with pytest.raises(SystemExit):
+        _apply_single_provider_filter(build_config, "vector_io=inline::nonexistent")
+
+
+def test_invalid_format_exits():
+    """Test error for invalid filter format."""
+    build_config = BuildConfig(
+        image_type=LlamaStackImageType.VENV.value,
+        distribution_spec=DistributionSpec(providers={}, description="Test"),
+    )
+
+    with pytest.raises(SystemExit):
+        _apply_single_provider_filter(build_config, "invalid_format")
diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py
index 8e5c85cf1..86d83e377 100644
--- a/tests/unit/providers/vector_io/conftest.py
+++ b/tests/unit/providers/vector_io/conftest.py
@@ -144,7 +144,6 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
         config=config,
         inference_api=mock_inference_api,
         files_api=None,
-        models_api=None,
     )
     collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
     await adapter.initialize()
@@ -183,7 +182,6 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
         config=config,
         inference_api=mock_inference_api,
         files_api=None,
-        models_api=None,
     )
     await adapter.initialize()
     await adapter.register_vector_db(
diff --git a/tests/unit/providers/vector_io/test_faiss.py b/tests/unit/providers/vector_io/test_faiss.py
index 76969b711..fa5c5f56b 100644
--- a/tests/unit/providers/vector_io/test_faiss.py
+++ b/tests/unit/providers/vector_io/test_faiss.py
@@ -11,7 +11,6 @@ import numpy as np
 import pytest

 from llama_stack.apis.files import Files
-from llama_stack.apis.models import Models
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
 from llama_stack.providers.datatypes import HealthStatus
@@ -76,12 +75,6 @@ def mock_files_api():
     return mock_api


-@pytest.fixture
-def mock_models_api():
-    mock_api = MagicMock(spec=Models)
-    return mock_api
-
-
 @pytest.fixture
 def faiss_config():
     config = MagicMock(spec=FaissVectorIOConfig)
@@ -117,7 +110,7 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
     assert response.chunks[1] == sample_chunks[1]


-async def test_health_success(mock_models_api):
+async def test_health_success():
     """Test that the health check returns OK status when faiss is working correctly."""
     # Create a fresh instance of FaissVectorIOAdapter for testing
     config = MagicMock()
@@ -126,9 +119,7 @@ async def test_health_success(mock_models_api):
         with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
             mock_index_flat.return_value = MagicMock()

-            adapter = FaissVectorIOAdapter(
-                config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
-            )
+            adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)

             # Calling the health method directly
             response = await adapter.health()
@@ -142,7 +133,7 @@ async def test_health_success(mock_models_api):
             mock_index_flat.assert_called_once_with(128)  # VECTOR_DIMENSION is 128


-async def test_health_failure(mock_models_api):
+async def test_health_failure():
     """Test that the health check returns ERROR status when faiss encounters an error."""
     # Create a fresh instance of FaissVectorIOAdapter for testing
     config = MagicMock()
@@ -152,9 +143,7 @@ async def test_health_failure(mock_models_api):
         with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
             mock_index_flat.side_effect = Exception("Test error")

-            adapter = FaissVectorIOAdapter(
-                config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
-            )
+            adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)

             # Calling the health method directly
             response = await adapter.health()
diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
index a3d2e3173..ad55b9336 100644
--- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
+++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
@@ -1162,5 +1162,5 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
     # Test with no embedding model provided
     params = OpenAICreateVectorStoreRequestWithExtraBody(name="test_store", metadata={})

-    with pytest.raises(ValueError, match="embedding_model is required in extra_body when creating a vector store"):
+    with pytest.raises(ValueError, match="embedding_model is required"):
         await vector_io_adapter.openai_create_vector_store(params)