From 80d9d50954f6bc444aef48505e33b2571e7a3de7 Mon Sep 17 00:00:00 2001 From: Ashwin Gangadhar Date: Wed, 19 Feb 2025 21:48:05 +0530 Subject: [PATCH] adding mongodb vector_io module updated mongodb.py from print to log add documentation for mongodb vector search module changed insert to update mongodb bug fix mongodb json object conversion error --- README.md | 1 + docs/source/providers/vector_io/mongodb.md | 35 ++++ llama_stack/providers/registry/vector_io.py | 10 + .../remote/vector_io/mongodb/__init__.py | 19 ++ .../remote/vector_io/mongodb/config.py | 18 +- .../mongodb/{mongodb_atlas.py => mongodb.py} | 173 +++++++++++----- .../providers/tests/vector_io/conftest.py | 116 +++++++++++ .../providers/tests/vector_io/fixtures.py | 196 ++++++++++++++++++ 8 files changed, 503 insertions(+), 65 deletions(-) create mode 100644 docs/source/providers/vector_io/mongodb.md rename llama_stack/providers/remote/vector_io/mongodb/{mongodb_atlas.py => mongodb.py} (51%) create mode 100644 llama_stack/providers/tests/vector_io/conftest.py create mode 100644 llama_stack/providers/tests/vector_io/fixtures.py diff --git a/README.md b/README.md index b24e69514..d81851a31 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ Here is a list of the various API providers and available distributions that can | NVIDIA NIM | Hosted and Single Node | | ✅ | | | | | Chroma | Single Node | | | ✅ | | | | PG Vector | Single Node | | | ✅ | | | +| MongoDB Atlas | Hosted | | | ✅ | | | | PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | | vLLM | Hosted and Single Node | | ✅ | | | | diff --git a/docs/source/providers/vector_io/mongodb.md b/docs/source/providers/vector_io/mongodb.md new file mode 100644 index 000000000..67e4adec0 --- /dev/null +++ b/docs/source/providers/vector_io/mongodb.md @@ -0,0 +1,35 @@ +--- +orphan: true +--- +# MongoDB Atlas + +[MongoDB Atlas](https://www.mongodb.com/atlas) is a cloud database service that can be used as a vector store provider for Llama Stack. 
It supports vector search capabilities through its Atlas Vector Search feature, allowing you to store and query vectors within your MongoDB database. + +## Features +MongoDB Atlas Vector Search supports: +- Storing embeddings and their metadata +- Vector search with multiple similarity metrics (cosine similarity, Euclidean distance, dot product) +- Hybrid search (combining vector and keyword search) +- Metadata filtering +- Scalable vector indexing +- Managed cloud infrastructure + +## Usage + +To use MongoDB Atlas in your Llama Stack project, follow these steps: + +1. Create a MongoDB Atlas account and cluster. +2. Configure your Atlas cluster to enable Vector Search. +3. Configure your Llama Stack project to use MongoDB Atlas. +4. Start storing and querying vectors. + +## Installation + +You can install the MongoDB Python driver using pip: + +```bash +pip install pymongo +``` + +## Documentation +See the [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/) for more details about vector search capabilities in MongoDB Atlas. 
diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 8471748d8..e3b81172a 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -112,6 +112,16 @@ def available_providers() -> List[ProviderSpec]: ), api_dependencies=[Api.inference], ), + remote_provider_spec( + Api.vector_io, + AdapterSpec( + adapter_type="mongodb", + pip_packages=["pymongo"], + module="llama_stack.providers.remote.vector_io.mongodb", + config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig", + ), + api_dependencies=[Api.inference], + ), remote_provider_spec( Api.vector_io, AdapterSpec( diff --git a/llama_stack/providers/remote/vector_io/mongodb/__init__.py b/llama_stack/providers/remote/vector_io/mongodb/__init__.py index e69de29bb..fd46551b1 100644 --- a/llama_stack/providers/remote/vector_io/mongodb/__init__.py +++ b/llama_stack/providers/remote/vector_io/mongodb/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + +from .config import MongoDBVectorIOConfig + + +async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from .mongodb import MongoDBVectorIOAdapter + + impl = MongoDBVectorIOAdapter(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/vector_io/mongodb/config.py b/llama_stack/providers/remote/vector_io/mongodb/config.py index 620594566..04e20b474 100644 --- a/llama_stack/providers/remote/vector_io/mongodb/config.py +++ b/llama_stack/providers/remote/vector_io/mongodb/config.py @@ -4,26 +4,22 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List from pydantic import BaseModel, Field class MongoDBVectorIOConfig(BaseModel): - conncetion_str: str - namespace: str = Field(None, description="Namespace of the MongoDB collection") - index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB collection") - filter_fields: Optional[str] = Field(None, description="Fields to filter the MongoDB collection") - embedding_field: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB collection") - text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB collection") + connection_str: str = Field(None, description="Connection string for the MongoDB Atlas collection") + namespace: str = Field(None, description="Namespace i.e. 
db_name.collection_name of the MongoDB Atlas collection") + index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB Atlas collection") + filter_fields: Optional[List[str]] = Field([], description="Fields to filter along side vector search in MongoDB Atlas collection") + embeddings_key: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB Atlas collection") + text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB Atlas collection") @classmethod def sample_config(cls) -> Dict[str, Any]: return { "connection_str": "{env.MONGODB_CONNECTION_STR}", "namespace": "{env.MONGODB_NAMESPACE}", - "index_name": "{env.MONGODB_INDEX_NAME}", - "filter_fields": "{env.MONGODB_FILTER_FIELDS}", - "embedding_field": "{env.MONGODB_EMBEDDING_FIELD}", - "text_field": "{env.MONGODB_TEXT_FIELD}", } diff --git a/llama_stack/providers/remote/vector_io/mongodb/mongodb_atlas.py b/llama_stack/providers/remote/vector_io/mongodb/mongodb.py similarity index 51% rename from llama_stack/providers/remote/vector_io/mongodb/mongodb_atlas.py rename to llama_stack/providers/remote/vector_io/mongodb/mongodb.py index bedb90754..cc172fc70 100644 --- a/llama_stack/providers/remote/vector_io/mongodb/mongodb_atlas.py +++ b/llama_stack/providers/remote/vector_io/mongodb/mongodb.py @@ -1,5 +1,3 @@ -import pymongo - # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
# @@ -11,8 +9,8 @@ import logging from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse -from pymongo import MongoClient, -from pymongo.operations import UpdateOne, InsertOne, DeleteOne, DeleteMany, SearchIndexModel +from pymongo import MongoClient +from pymongo.operations import InsertOne, SearchIndexModel, UpdateOne import certifi from numpy.typing import NDArray @@ -25,19 +23,24 @@ from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) + +from .config import MongoDBVectorIOConfig -from .config import MongoDBAtlasVectorIOConfig from time import sleep log = logging.getLogger(__name__) CHUNK_ID_KEY = "_chunk_id" + class MongoDBAtlasIndex(EmbeddingIndex): - def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, index_name: str): + + def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, embedding_dimension: str, index_name: str, filter_fields: List[str]): self.client = client self.namespace = namespace self.embeddings_key = embeddings_key self.index_name = index_name + self.filter_fields = filter_fields + self.embedding_dimension = embedding_dimension def _get_index_config(self, collection, index_name): idxs = list(collection.list_search_indexes()) @@ -45,14 +48,39 @@ class MongoDBAtlasIndex(EmbeddingIndex): if ele["name"] == index_name: return ele + def _get_search_index_model(self): + index_fields = [ + { + "path": self.embeddings_key, + "type": "vector", + "numDimensions": self.embedding_dimension, + "similarity": "cosine" + } + ] + + if len(self.filter_fields) > 0: + for filter_field in self.filter_fields: + index_fields.append( + { + "path": filter_field, + "type": "filter" + } + ) + + return SearchIndexModel( + name=self.index_name, + type="vectorSearch", + definition={ + "fields": index_fields + } + ) + def _check_n_create_index(self): client = self.client - db,collection = self.namespace.split(".") + db, collection = 
self.namespace.split(".") collection = client[db][collection] index_name = self.index_name - print(">>>>>>>>Index name: ", index_name, "<<<<<<<<<<") idx = self._get_index_config(collection, index_name) - print(idx) if not idx: log.info("Creating search index ...") search_index_model = self._get_search_index_model() @@ -60,10 +88,10 @@ class MongoDBAtlasIndex(EmbeddingIndex): while True: idx = self._get_index_config(collection, index_name) if idx and idx["queryable"]: - print("Search index created successfully.") + log.info("Search index created successfully.") break else: - print("Waiting for search index to be created ...") + log.info("Waiting for search index to be created ...") sleep(5) else: log.info("Search index already exists.") @@ -77,46 +105,53 @@ class MongoDBAtlasIndex(EmbeddingIndex): operations = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}" + operations.append( - InsertOne( + UpdateOne( + {CHUNK_ID_KEY: chunk_id}, { - CHUNK_ID_KEY: chunk_id, - "chunk_content": chunk.model_dump_json(), - self.embeddings_key: embedding.tolist(), - } + "$set": { + CHUNK_ID_KEY: chunk_id, + "chunk_content": json.loads(chunk.model_dump_json()), + self.embeddings_key: embedding.tolist(), + } + }, + upsert=True, ) ) # Perform the bulk operations - db,collection_name = self.namespace.split(".") + db, collection_name = self.namespace.split(".") collection = self.client[db][collection_name] collection.bulk_write(operations) - - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: - + print(f"Added {len(chunks)} chunks to the collection") # Create a search index model + print("Creating search index ...") self._check_n_create_index() + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: # Perform a query - db,collection_name = self.namespace.split(".") + db, collection_name = self.namespace.split(".") collection = 
self.client[db][collection_name] # Create vector search query - vs_query = {"$vectorSearch": - { - "index": "vector_index", - "path": self.embeddings_key, - "queryVector": embedding.tolist(), - "numCandidates": k, - "limit": k, - } - } + vs_query = {"$vectorSearch": + { + "index": self.index_name, + "path": self.embeddings_key, + "queryVector": embedding.tolist(), + "numCandidates": k, + "limit": k, + } + } # Add a field to store the score score_add_field_query = { "$addFields": { "score": {"$meta": "vectorSearchScore"} } } + if score_threshold is None: + score_threshold = 0.01 # Filter the results based on the score threshold filter_query = { "$match": { @@ -141,60 +176,90 @@ class MongoDBAtlasIndex(EmbeddingIndex): chunks = [] scores = [] for result in results: + content = result["chunk_content"] chunk = Chunk( - metadata={"document_id": result[CHUNK_ID_KEY]}, - content=json.loads(result["chunk_content"]), + metadata=content["metadata"], + content=content["content"], ) chunks.append(chunk) scores.append(result["score"]) return QueryChunksResponse(chunks=chunks, scores=scores) - -class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): - def __init__(self, config: MongoDBAtlasVectorIOConfig, inference_api: Api.inference): + async def delete(self): + db, _ = self.namespace.split(".") + self.client.drop_database(db) + + +class MongoDBVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): + def __init__(self, config: MongoDBVectorIOConfig, inference_api: Api.inference): self.config = config self.inference_api = inference_api - + self.cache = {} async def initialize(self) -> None: self.client = MongoClient( - self.config.uri, + self.config.connection_str, tlsCAFile=certifi.where(), ) - self.cache = {} - pass async def shutdown(self) -> None: - self.client.close() - pass + if not self.client: + self.client.close() - async def register_vector_db( self, vector_db: VectorDB) -> None: - index = VectorDBWithIndex( + async def register_vector_db(self, vector_db: 
VectorDB) -> None: + index=MongoDBAtlasIndex( + client=self.client, + namespace=self.config.namespace, + embeddings_key=self.config.embeddings_key, + embedding_dimension=vector_db.embedding_dimension, + index_name=self.config.index_name, + filter_fields=self.config.filter_fields, + ) + self.cache[vector_db.identifier] = VectorDBWithIndex( vector_db=vector_db, + index=index, + inference_api=self.inference_api, + ) + + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex: + if vector_db_id in self.cache: + return self.cache[vector_db_id] + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) + self.cache[vector_db_id] = VectorDBWithIndex( + vector_db=vector_db_id, index=MongoDBAtlasIndex( client=self.client, namespace=self.config.namespace, embeddings_key=self.config.embeddings_key, + embedding_dimension=vector_db.embedding_dimension, index_name=self.config.index_name, + filter_fields=self.config.filter_fields, ), + inference_api=self.inference_api, ) - self.cache[vector_db] = index - pass + return self.cache[vector_db_id] - async def insert_chunks(self, vector_db_id, chunks, ttl_seconds = None): - index = self.cache[vector_db_id].index + async def unregister_vector_db(self, vector_db_id: str) -> None: + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + + async def insert_chunks(self, + vector_db_id: str, + chunks: List[Chunk], + ttl_seconds: Optional[int] = None, + ) -> None: + index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: raise ValueError(f"Vector DB {vector_db_id} not found") await index.insert_chunks(chunks) - - async def query_chunks(self, vector_db_id, query, params = None): - index = self.cache[vector_db_id].index + async def query_chunks(self, + vector_db_id: str, + query: InterleavedContent, + params: Optional[Dict[str, Any]] = None, + ) -> QueryChunksResponse: + index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: raise 
ValueError(f"Vector DB {vector_db_id} not found") return await index.query_chunks(query, params) - - - - diff --git a/llama_stack/providers/tests/vector_io/conftest.py b/llama_stack/providers/tests/vector_io/conftest.py new file mode 100644 index 000000000..776c0458f --- /dev/null +++ b/llama_stack/providers/tests/vector_io/conftest.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from ..conftest import ( + get_provider_fixture_overrides, + get_provider_fixture_overrides_from_test_config, + get_test_config_for_api, +) +from ..inference.fixtures import INFERENCE_FIXTURES +from .fixtures import VECTOR_IO_FIXTURES + +DEFAULT_PROVIDER_COMBINATIONS = [ + pytest.param( + { + "inference": "sentence_transformers", + "vector_io": "faiss", + }, + id="sentence_transformers", + marks=pytest.mark.sentence_transformers, + ), + pytest.param( + { + "inference": "ollama", + "vector_io": "pgvector", + }, + id="pgvector", + marks=pytest.mark.pgvector, + ), + pytest.param( + { + "inference": "ollama", + "vector_io": "faiss", + }, + id="ollama", + marks=pytest.mark.ollama, + ), + pytest.param( + { + "inference": "ollama", + "vector_io": "sqlite_vec", + }, + id="sqlite_vec", + marks=pytest.mark.ollama, + ), + pytest.param( + { + "inference": "sentence_transformers", + "vector_io": "chroma", + }, + id="chroma", + marks=pytest.mark.chroma, + ), + pytest.param( + { + "inference": "ollama", + "vector_io": "qdrant", + }, + id="qdrant", + marks=pytest.mark.qdrant, + ), + pytest.param( + { + "inference": "fireworks", + "vector_io": "weaviate", + }, + id="weaviate", + marks=pytest.mark.weaviate, + ), + pytest.param( + { + "inference": "bedrock", + "vector_io": "mongodb", + }, + id="mongodb", + marks=pytest.mark.mongodb, + ), +] + + +def pytest_configure(config): + for fixture_name in 
VECTOR_IO_FIXTURES: + config.addinivalue_line( + "markers", + f"{fixture_name}: marks tests as {fixture_name} specific", + ) + + +def pytest_generate_tests(metafunc): + test_config = get_test_config_for_api(metafunc.config, "vector_io") + if "embedding_model" in metafunc.fixturenames: + model = getattr(test_config, "embedding_model", None) + # Fall back to the default if not specified by the config file + model = model or metafunc.config.getoption("--embedding-model") + if model: + params = [pytest.param(model, id="")] + else: + params = [pytest.param("all-minilm:l6-v2", id="")] + + metafunc.parametrize("embedding_model", params, indirect=True) + + if "vector_io_stack" in metafunc.fixturenames: + available_fixtures = { + "inference": INFERENCE_FIXTURES, + "vector_io": VECTOR_IO_FIXTURES, + } + combinations = ( + get_provider_fixture_overrides_from_test_config(metafunc.config, "vector_io", DEFAULT_PROVIDER_COMBINATIONS) + or get_provider_fixture_overrides(metafunc.config, available_fixtures) + or DEFAULT_PROVIDER_COMBINATIONS + ) + metafunc.parametrize("vector_io_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/vector_io/fixtures.py b/llama_stack/providers/tests/vector_io/fixtures.py new file mode 100644 index 000000000..a4c89a77b --- /dev/null +++ b/llama_stack/providers/tests/vector_io/fixtures.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import os +import tempfile + +import pytest +import pytest_asyncio + +from llama_stack.apis.models import ModelInput, ModelType +from llama_stack.distribution.datatypes import Api, Provider +from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig +from llama_stack.providers.inline.vector_io.faiss import FaissVectorIOConfig +from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig +from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig +from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig +from llama_stack.providers.remote.vector_io.qdrant import QdrantVectorIOConfig +from llama_stack.providers.remote.vector_io.weaviate import WeaviateVectorIOConfig +from llama_stack.providers.remote.vector_io.mongodb import MongoDBVectorIOConfig +from llama_stack.providers.tests.resolver import construct_stack_for_test +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + +from ..conftest import ProviderFixture, remote_stack_fixture +from ..env import get_env_or_fail + + +@pytest.fixture(scope="session") +def embedding_model(request): + if hasattr(request, "param"): + return request.param + return request.config.getoption("--embedding-model", None) + + +@pytest.fixture(scope="session") +def vector_io_remote() -> ProviderFixture: + return remote_stack_fixture() + + +@pytest.fixture(scope="session") +def vector_io_faiss() -> ProviderFixture: + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + return ProviderFixture( + providers=[ + Provider( + provider_id="faiss", + provider_type="inline::faiss", + config=FaissVectorIOConfig( + kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(), + ).model_dump(), + ) + ], + ) + + +@pytest.fixture(scope="session") +def vector_io_sqlite_vec() -> ProviderFixture: + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + return ProviderFixture( + 
providers=[ + Provider( + provider_id="sqlite_vec", + provider_type="inline::sqlite_vec", + config=SQLiteVectorIOConfig( + kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(), + ).model_dump(), + ) + ], + ) + + +@pytest.fixture(scope="session") +def vector_io_pgvector() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="pgvector", + provider_type="remote::pgvector", + config=PGVectorVectorIOConfig( + host=os.getenv("PGVECTOR_HOST", "localhost"), + port=os.getenv("PGVECTOR_PORT", 5432), + db=get_env_or_fail("PGVECTOR_DB"), + user=get_env_or_fail("PGVECTOR_USER"), + password=get_env_or_fail("PGVECTOR_PASSWORD"), + ).model_dump(), + ) + ], + ) + + +@pytest.fixture(scope="session") +def vector_io_weaviate() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="weaviate", + provider_type="remote::weaviate", + config=WeaviateVectorIOConfig().model_dump(), + ) + ], + provider_data=dict( + weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"), + weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"), + ), + ) + + +@pytest.fixture(scope="session") +def vector_io_chroma() -> ProviderFixture: + url = os.getenv("CHROMA_URL") + if url: + config = ChromaVectorIOConfig(url=url) + provider_type = "remote::chromadb" + else: + if not os.getenv("CHROMA_DB_PATH"): + raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set") + config = InlineChromaVectorIOConfig(db_path=os.getenv("CHROMA_DB_PATH")) + provider_type = "inline::chromadb" + return ProviderFixture( + providers=[ + Provider( + provider_id="chroma", + provider_type=provider_type, + config=config.model_dump(), + ) + ] + ) + + +@pytest.fixture(scope="session") +def vector_io_qdrant() -> ProviderFixture: + url = os.getenv("QDRANT_URL") + if url: + config = QdrantVectorIOConfig(url=url) + provider_type = "remote::qdrant" + else: + raise ValueError("QDRANT_URL must be set") + return ProviderFixture( + providers=[ + Provider( + 
provider_id="qdrant", + provider_type=provider_type, + config=config.model_dump(), + ) + ] + ) + +@pytest.fixture(scope="session") +def vector_io_mongodb() -> ProviderFixture: + connection_str = get_env_or_fail("MONGODB_CONNECTION_STR") + namespace = get_env_or_fail("MONGODB_NAMESPACE") + config = MongoDBVectorIOConfig(connection_str=connection_str, namespace=namespace) + provider_type = "remote::mongodb" + return ProviderFixture( + providers=[ + Provider( + provider_id="mongodb", + provider_type=provider_type, + config=config.model_dump(), + ) + ] + ) + +VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma", "qdrant", "sqlite_vec", "mongodb"] + + +@pytest_asyncio.fixture(scope="session") +async def vector_io_stack(embedding_model, request): + fixture_dict = request.param + + providers = {} + provider_data = {} + for key in ["inference", "vector_io"]: + fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") + providers[key] = fixture.providers + if fixture.provider_data: + provider_data.update(fixture.provider_data) + + test_stack = await construct_stack_for_test( + [Api.vector_io, Api.inference], + providers, + provider_data, + models=[ + ModelInput( + model_id=embedding_model, + model_type=ModelType.embedding, + metadata={ + "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"), + }, + ) + ], + ) + + return test_stack.impls[Api.vector_io], test_stack.impls[Api.vector_dbs]