mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-10 04:08:31 +00:00
adding mongodb vector_io module
Updated mongodb.py from print to log; added documentation for the MongoDB vector search module; changed insert to update; fixed MongoDB JSON object conversion error.
This commit is contained in:
parent
d224ae0c8e
commit
80d9d50954
8 changed files with 503 additions and 65 deletions
|
@ -48,6 +48,7 @@ Here is a list of the various API providers and available distributions that can
|
|||
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
|
||||
| Chroma | Single Node | | | ✅ | | |
|
||||
| PG Vector | Single Node | | | ✅ | | |
|
||||
| MongoDB Atlas | Hosted | | | ✅ | | |
|
||||
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
|
||||
| vLLM | Hosted and Single Node | | ✅ | | | |
|
||||
|
||||
|
|
35
docs/source/providers/vector_io/mongodb.md
Normal file
35
docs/source/providers/vector_io/mongodb.md
Normal file
|
@ -0,0 +1,35 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# MongoDB Atlas
|
||||
|
||||
[MongoDB Atlas](https://www.mongodb.com/atlas) is a cloud database service that can be used as a vector store provider for Llama Stack. It supports vector search capabilities through its Atlas Vector Search feature, allowing you to store and query vectors within your MongoDB database.
|
||||
|
||||
## Features
|
||||
MongoDB Atlas Vector Search supports:
|
||||
- Storing embeddings and their metadata
|
||||
- Vector search with multiple similarity functions (cosine similarity, Euclidean distance, dot product)
|
||||
- Hybrid search (combining vector and keyword search)
|
||||
- Metadata filtering
|
||||
- Scalable vector indexing
|
||||
- Managed cloud infrastructure
|
||||
|
||||
## Usage
|
||||
|
||||
To use MongoDB Atlas in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Create a MongoDB Atlas account and cluster.
|
||||
2. Configure your Atlas cluster to enable Vector Search.
|
||||
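3. Configure your Llama Stack project to use MongoDB Atlas by adding the `remote::mongodb` vector_io provider with your connection string and namespace (`db_name.collection_name`).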
|
||||
4. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
||||
You can install the MongoDB Python driver using pip:
|
||||
|
||||
```bash
|
||||
pip install pymongo
|
||||
```
|
||||
|
||||
## Documentation
|
||||
See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/) for more details about vector search capabilities in MongoDB Atlas.
|
|
@ -112,6 +112,16 @@ def available_providers() -> List[ProviderSpec]:
|
|||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="mongodb",
|
||||
pip_packages=["pymongo"],
|
||||
module="llama_stack.providers.remote.vector_io.mongodb",
|
||||
config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MongoDBVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: Dict[Api, ProviderSpec]):
    """Create and initialize the MongoDB vector_io adapter.

    Args:
        config: Provider configuration for the adapter.
        deps: Resolved dependencies keyed by API; ``deps[Api.inference]`` is
            handed to the adapter.

    Returns:
        The adapter instance, after ``initialize()`` has been awaited.
    """
    # Imported inside the function rather than at module level
    # (presumably so pymongo is only required when this provider is used).
    from .mongodb import MongoDBVectorIOAdapter

    adapter = MongoDBVectorIOAdapter(config, deps[Api.inference])
    await adapter.initialize()
    return adapter
|
|
@ -4,26 +4,22 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MongoDBVectorIOConfig(BaseModel):
    """Configuration for the MongoDB Atlas vector_io provider."""

    # NOTE(review): connection_str and namespace are annotated `str` but default
    # to None; the default is kept so existing no-argument construction still works.
    connection_str: str = Field(None, description="Connection string for the MongoDB Atlas collection")
    namespace: str = Field(None, description="Namespace i.e. db_name.collection_name of the MongoDB Atlas collection")
    index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB Atlas collection")
    filter_fields: Optional[List[str]] = Field(
        default_factory=list,
        description="Fields to filter along side vector search in MongoDB Atlas collection",
    )
    embeddings_key: Optional[str] = Field(
        "embeddings", description="Field name for the embeddings in the MongoDB Atlas collection"
    )
    text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB Atlas collection")

    @classmethod
    def sample_config(cls) -> Dict[str, Any]:
        """Return a sample configuration whose values are ``{env.*}`` placeholders.

        Every key matches a field declared on this model.
        """
        return {
            "connection_str": "{env.MONGODB_CONNECTION_STR}",
            "namespace": "{env.MONGODB_NAMESPACE}",
            "index_name": "{env.MONGODB_INDEX_NAME}",
            # NOTE(review): filter_fields is List[str] but the placeholder is a
            # single string - confirm how the env value is expanded.
            "filter_fields": "{env.MONGODB_FILTER_FIELDS}",
            # The key must be `embeddings_key` (the declared field name); the
            # environment variable name is left unchanged.
            "embeddings_key": "{env.MONGODB_EMBEDDING_FIELD}",
            "text_field": "{env.MONGODB_TEXT_FIELD}",
        }
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import pymongo
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
|
@ -11,8 +9,8 @@ import logging
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pymongo import MongoClient,
|
||||
from pymongo.operations import UpdateOne, InsertOne, DeleteOne, DeleteMany, SearchIndexModel
|
||||
from pymongo import MongoClient
|
||||
from pymongo.operations import InsertOne, SearchIndexModel, UpdateOne
|
||||
import certifi
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
@ -26,18 +24,23 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
from .config import MongoDBAtlasVectorIOConfig
|
||||
from .config import MongoDBVectorIOConfig
|
||||
|
||||
from time import sleep
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
CHUNK_ID_KEY = "_chunk_id"
|
||||
|
||||
|
||||
class MongoDBAtlasIndex(EmbeddingIndex):
|
||||
def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, index_name: str):
|
||||
|
||||
def __init__(
    self,
    client: MongoClient,
    namespace: str,
    embeddings_key: str,
    embedding_dimension: int,
    index_name: str,
    filter_fields: Optional[List[str]],
):
    """Store the connection and index settings used by the other methods.

    Args:
        client: MongoDB client used for all collection access.
        namespace: ``db_name.collection_name`` of the backing collection.
        embeddings_key: Document field that holds the embedding vector.
        embedding_dimension: Number of dimensions of the stored vectors;
            used as ``numDimensions`` in the search index definition.
        index_name: Name of the search index.
        filter_fields: Document fields indexed as filters; may be empty
            or None.
    """
    self.client = client
    self.namespace = namespace
    self.embeddings_key = embeddings_key
    self.index_name = index_name
    self.filter_fields = filter_fields
    self.embedding_dimension = embedding_dimension
|
||||
|
||||
def _get_index_config(self, collection, index_name):
|
||||
idxs = list(collection.list_search_indexes())
|
||||
|
@ -45,14 +48,39 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
|||
if ele["name"] == index_name:
|
||||
return ele
|
||||
|
||||
def _get_search_index_model(self):
    """Build the ``vectorSearch`` index model for the configured collection.

    The definition always contains one cosine-similarity vector field for
    ``self.embeddings_key`` and one ``filter`` field per entry in
    ``self.filter_fields``.

    Returns:
        A ``SearchIndexModel`` named ``self.index_name``.
    """
    vector_field = {
        "path": self.embeddings_key,
        "type": "vector",
        "numDimensions": self.embedding_dimension,
        "similarity": "cosine",
    }
    # filter_fields is Optional in the provider config, so treat None as "no filters".
    filter_specs = [{"path": field_name, "type": "filter"} for field_name in (self.filter_fields or [])]

    return SearchIndexModel(
        name=self.index_name,
        type="vectorSearch",
        definition={"fields": [vector_field, *filter_specs]},
    )
|
||||
|
||||
def _check_n_create_index(self):
|
||||
client = self.client
|
||||
db,collection = self.namespace.split(".")
|
||||
db, collection = self.namespace.split(".")
|
||||
collection = client[db][collection]
|
||||
index_name = self.index_name
|
||||
print(">>>>>>>>Index name: ", index_name, "<<<<<<<<<<")
|
||||
idx = self._get_index_config(collection, index_name)
|
||||
print(idx)
|
||||
if not idx:
|
||||
log.info("Creating search index ...")
|
||||
search_index_model = self._get_search_index_model()
|
||||
|
@ -60,10 +88,10 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
|||
while True:
|
||||
idx = self._get_index_config(collection, index_name)
|
||||
if idx and idx["queryable"]:
|
||||
print("Search index created successfully.")
|
||||
log.info("Search index created successfully.")
|
||||
break
|
||||
else:
|
||||
print("Waiting for search index to be created ...")
|
||||
log.info("Waiting for search index to be created ...")
|
||||
sleep(5)
|
||||
else:
|
||||
log.info("Search index already exists.")
|
||||
|
@ -77,46 +105,53 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
|||
operations = []
|
||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
|
||||
|
||||
operations.append(
|
||||
InsertOne(
|
||||
UpdateOne(
|
||||
{CHUNK_ID_KEY: chunk_id},
|
||||
{
|
||||
CHUNK_ID_KEY: chunk_id,
|
||||
"chunk_content": chunk.model_dump_json(),
|
||||
self.embeddings_key: embedding.tolist(),
|
||||
}
|
||||
"$set": {
|
||||
CHUNK_ID_KEY: chunk_id,
|
||||
"chunk_content": json.loads(chunk.model_dump_json()),
|
||||
self.embeddings_key: embedding.tolist(),
|
||||
}
|
||||
},
|
||||
upsert=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Perform the bulk operations
|
||||
db,collection_name = self.namespace.split(".")
|
||||
db, collection_name = self.namespace.split(".")
|
||||
collection = self.client[db][collection_name]
|
||||
collection.bulk_write(operations)
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
|
||||
print(f"Added {len(chunks)} chunks to the collection")
|
||||
# Create a search index model
|
||||
print("Creating search index ...")
|
||||
self._check_n_create_index()
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
# Perform a query
|
||||
db,collection_name = self.namespace.split(".")
|
||||
db, collection_name = self.namespace.split(".")
|
||||
collection = self.client[db][collection_name]
|
||||
|
||||
# Create vector search query
|
||||
vs_query = {"$vectorSearch":
|
||||
{
|
||||
"index": "vector_index",
|
||||
"path": self.embeddings_key,
|
||||
"queryVector": embedding.tolist(),
|
||||
"numCandidates": k,
|
||||
"limit": k,
|
||||
}
|
||||
}
|
||||
{
|
||||
"index": self.index_name,
|
||||
"path": self.embeddings_key,
|
||||
"queryVector": embedding.tolist(),
|
||||
"numCandidates": k,
|
||||
"limit": k,
|
||||
}
|
||||
}
|
||||
# Add a field to store the score
|
||||
score_add_field_query = {
|
||||
"$addFields": {
|
||||
"score": {"$meta": "vectorSearchScore"}
|
||||
}
|
||||
}
|
||||
if score_threshold is None:
|
||||
score_threshold = 0.01
|
||||
# Filter the results based on the score threshold
|
||||
filter_query = {
|
||||
"$match": {
|
||||
|
@ -141,60 +176,90 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
|||
chunks = []
|
||||
scores = []
|
||||
for result in results:
|
||||
content = result["chunk_content"]
|
||||
chunk = Chunk(
|
||||
metadata={"document_id": result[CHUNK_ID_KEY]},
|
||||
content=json.loads(result["chunk_content"]),
|
||||
metadata=content["metadata"],
|
||||
content=content["content"],
|
||||
)
|
||||
chunks.append(chunk)
|
||||
scores.append(result["score"])
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self):
    """Delete the data backing this index.

    Drops only the collection named by ``self.namespace``, not the whole
    database, so other collections in the same database are left untouched.
    """
    # Split on the first dot only: collection names may themselves contain dots.
    db, collection_name = self.namespace.split(".", 1)
    self.client[db][collection_name].drop()
|
||||
|
||||
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: MongoDBAtlasVectorIOConfig, inference_api: Api.inference):
|
||||
|
||||
class MongoDBVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    """vector_io adapter that stores and queries chunks in a MongoDB collection.

    Each registered vector DB is wrapped in a ``VectorDBWithIndex`` backed by a
    ``MongoDBAtlasIndex`` and cached by identifier.
    """

    def __init__(self, config: MongoDBVectorIOConfig, inference_api: Api.inference):
        self.config = config
        self.inference_api = inference_api
        # Created in initialize(); None until then so shutdown() is always safe.
        self.client = None
        self.cache = {}

    async def initialize(self) -> None:
        """Open the MongoDB client using the configured connection string."""
        self.client = MongoClient(
            self.config.connection_str,
            tlsCAFile=certifi.where(),
        )

    async def shutdown(self) -> None:
        """Close the MongoDB client if one was opened."""
        # Compare against None explicitly instead of relying on the client's truthiness.
        if self.client is not None:
            self.client.close()
            self.client = None

    def _build_index(self, vector_db: VectorDB) -> MongoDBAtlasIndex:
        """Create the MongoDB-backed index for ``vector_db`` from the provider config."""
        return MongoDBAtlasIndex(
            client=self.client,
            namespace=self.config.namespace,
            embeddings_key=self.config.embeddings_key,
            embedding_dimension=vector_db.embedding_dimension,
            index_name=self.config.index_name,
            filter_fields=self.config.filter_fields,
        )

    async def register_vector_db(self, vector_db: VectorDB) -> None:
        """Create an index wrapper for ``vector_db`` and cache it by identifier."""
        self.cache[vector_db.identifier] = VectorDBWithIndex(
            vector_db=vector_db,
            index=self._build_index(vector_db),
            inference_api=self.inference_api,
        )

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
        """Return the cached wrapper for ``vector_db_id``, building it on a cache miss.

        Raises:
            ValueError: If the vector DB store has no entry for ``vector_db_id``.
        """
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
        if not vector_db:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        # Pass the VectorDB object itself (not its id string) to the wrapper.
        self.cache[vector_db_id] = VectorDBWithIndex(
            vector_db=vector_db,
            index=self._build_index(vector_db),
            inference_api=self.inference_api,
        )
        return self.cache[vector_db_id]

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Delete the index data for ``vector_db_id`` and evict it from the cache.

        Raises:
            ValueError: If the vector DB is neither cached nor known to the store.
        """
        # Resolve through the helper so this also works for an id that is not cached yet.
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        await index.index.delete()
        del self.cache[vector_db_id]

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: List[Chunk],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """Insert ``chunks`` into the vector DB identified by ``vector_db_id``.

        ``ttl_seconds`` is accepted but not used by this adapter.

        Raises:
            ValueError: If the vector DB cannot be found.
        """
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryChunksResponse:
        """Query the vector DB identified by ``vector_db_id``.

        Raises:
            ValueError: If the vector DB cannot be found.
        """
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        return await index.query_chunks(query, params)
|
||||
|
||||
|
||||
|
||||
|
116
llama_stack/providers/tests/vector_io/conftest.py
Normal file
116
llama_stack/providers/tests/vector_io/conftest.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from ..conftest import (
|
||||
get_provider_fixture_overrides,
|
||||
get_provider_fixture_overrides_from_test_config,
|
||||
get_test_config_for_api,
|
||||
)
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from .fixtures import VECTOR_IO_FIXTURES
|
||||
|
||||
# Default (inference, vector_io) provider pairings used when neither the test
# config nor the command line selects providers. Each row is
# (inference fixture, vector_io fixture, test id, pytest marker name).
DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {"inference": inference_name, "vector_io": vector_io_name},
        id=test_id,
        marks=getattr(pytest.mark, marker_name),
    )
    for inference_name, vector_io_name, test_id, marker_name in (
        ("sentence_transformers", "faiss", "sentence_transformers", "sentence_transformers"),
        ("ollama", "pgvector", "pgvector", "pgvector"),
        ("ollama", "faiss", "ollama", "ollama"),
        ("ollama", "sqlite_vec", "sqlite_vec", "ollama"),
        ("sentence_transformers", "chroma", "chroma", "chroma"),
        ("ollama", "qdrant", "qdrant", "qdrant"),
        ("fireworks", "weaviate", "weaviate", "weaviate"),
        ("bedrock", "mongodb", "mongodb", "mongodb"),
    )
]
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Register one pytest marker for each name in ``VECTOR_IO_FIXTURES``."""
    for name in VECTOR_IO_FIXTURES:
        marker_line = f"{name}: marks tests as {name} specific"
        config.addinivalue_line("markers", marker_line)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
    """Parametrize the ``embedding_model`` and ``vector_io_stack`` fixtures.

    The embedding model comes from the vector_io test config, then the
    ``--embedding-model`` option, then a built-in default. Provider
    combinations come from the test config, then command-line fixture
    overrides, then ``DEFAULT_PROVIDER_COMBINATIONS``.
    """
    test_config = get_test_config_for_api(metafunc.config, "vector_io")
    requested = metafunc.fixturenames

    if "embedding_model" in requested:
        # The first truthy source wins; the literal is the last-resort default.
        chosen_model = (
            getattr(test_config, "embedding_model", None)
            or metafunc.config.getoption("--embedding-model")
            or "all-minilm:l6-v2"
        )
        metafunc.parametrize("embedding_model", [pytest.param(chosen_model, id="")], indirect=True)

    if "vector_io_stack" in requested:
        fixtures_by_api = {
            "inference": INFERENCE_FIXTURES,
            "vector_io": VECTOR_IO_FIXTURES,
        }
        combinations = get_provider_fixture_overrides_from_test_config(
            metafunc.config, "vector_io", DEFAULT_PROVIDER_COMBINATIONS
        )
        if not combinations:
            combinations = get_provider_fixture_overrides(metafunc.config, fixtures_by_api)
        if not combinations:
            combinations = DEFAULT_PROVIDER_COMBINATIONS
        metafunc.parametrize("vector_io_stack", combinations, indirect=True)
|
196
llama_stack/providers/tests/vector_io/fixtures.py
Normal file
196
llama_stack/providers/tests/vector_io/fixtures.py
Normal file
|
@ -0,0 +1,196 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.models import ModelInput, ModelType
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.qdrant import QdrantVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.mongodb import MongoDBVectorIOConfig
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def embedding_model(request):
    """Return the embedding model: the parametrized value if any, else ``--embedding-model``."""
    if not hasattr(request, "param"):
        return request.config.getoption("--embedding-model", None)
    return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def vector_io_remote() -> ProviderFixture:
    """Return the provider fixture produced by ``remote_stack_fixture()``."""
    remote_fixture = remote_stack_fixture()
    return remote_fixture
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def vector_io_faiss() -> ProviderFixture:
    """Return an inline Faiss provider whose kvstore is a fresh temporary SQLite file."""
    # Only the path is needed: close the handle immediately so it is not leaked.
    # delete=False keeps the file on disk for the kvstore to open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as temp_file:
        db_path = temp_file.name

    return ProviderFixture(
        providers=[
            Provider(
                provider_id="faiss",
                provider_type="inline::faiss",
                config=FaissVectorIOConfig(
                    kvstore=SqliteKVStoreConfig(db_path=db_path).model_dump(),
                ).model_dump(),
            )
        ],
    )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def vector_io_sqlite_vec() -> ProviderFixture:
    """Return an inline sqlite_vec provider whose kvstore is a fresh temporary SQLite file."""
    # Only the path is needed: close the handle immediately so it is not leaked.
    # delete=False keeps the file on disk for the kvstore to open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as temp_file:
        db_path = temp_file.name

    return ProviderFixture(
        providers=[
            Provider(
                provider_id="sqlite_vec",
                provider_type="inline::sqlite_vec",
                config=SQLiteVectorIOConfig(
                    kvstore=SqliteKVStoreConfig(db_path=db_path).model_dump(),
                ).model_dump(),
            )
        ],
    )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def vector_io_pgvector() -> ProviderFixture:
    """Return a remote PGVector provider configured from ``PGVECTOR_*`` environment variables.

    Host and port fall back to ``localhost`` / ``5432``; database, user and
    password are read with ``get_env_or_fail``.
    """
    pg_config = PGVectorVectorIOConfig(
        host=os.getenv("PGVECTOR_HOST", "localhost"),
        port=os.getenv("PGVECTOR_PORT", 5432),
        db=get_env_or_fail("PGVECTOR_DB"),
        user=get_env_or_fail("PGVECTOR_USER"),
        password=get_env_or_fail("PGVECTOR_PASSWORD"),
    )
    pg_provider = Provider(
        provider_id="pgvector",
        provider_type="remote::pgvector",
        config=pg_config.model_dump(),
    )
    return ProviderFixture(providers=[pg_provider])
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def vector_io_weaviate() -> ProviderFixture:
    """Return a remote Weaviate provider.

    The API key and cluster URL are read with ``get_env_or_fail`` and passed
    as provider data rather than provider config.
    """
    weaviate_provider = Provider(
        provider_id="weaviate",
        provider_type="remote::weaviate",
        config=WeaviateVectorIOConfig().model_dump(),
    )
    credentials = {
        "weaviate_api_key": get_env_or_fail("WEAVIATE_API_KEY"),
        "weaviate_cluster_url": get_env_or_fail("WEAVIATE_CLUSTER_URL"),
    }
    return ProviderFixture(providers=[weaviate_provider], provider_data=credentials)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def vector_io_chroma() -> ProviderFixture:
    """Return a Chroma provider: remote if ``CHROMA_URL`` is set, else inline from ``CHROMA_DB_PATH``.

    Raises:
        ValueError: If neither ``CHROMA_URL`` nor ``CHROMA_DB_PATH`` is set.
    """
    remote_url = os.getenv("CHROMA_URL")
    if not remote_url:
        db_path = os.getenv("CHROMA_DB_PATH")
        if not db_path:
            raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
        provider_type = "inline::chromadb"
        config = InlineChromaVectorIOConfig(db_path=db_path)
    else:
        provider_type = "remote::chromadb"
        config = ChromaVectorIOConfig(url=remote_url)

    chroma_provider = Provider(
        provider_id="chroma",
        provider_type=provider_type,
        config=config.model_dump(),
    )
    return ProviderFixture(providers=[chroma_provider])
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def vector_io_qdrant() -> ProviderFixture:
    """Return a remote Qdrant provider for the server at ``QDRANT_URL``.

    Raises:
        ValueError: If ``QDRANT_URL`` is not set.
    """
    qdrant_url = os.getenv("QDRANT_URL")
    if not qdrant_url:
        raise ValueError("QDRANT_URL must be set")

    qdrant_provider = Provider(
        provider_id="qdrant",
        provider_type="remote::qdrant",
        config=QdrantVectorIOConfig(url=qdrant_url).model_dump(),
    )
    return ProviderFixture(providers=[qdrant_provider])
|
||||
|
||||
@pytest.fixture(scope="session")
def vector_io_mongodb() -> ProviderFixture:
    """Return a remote MongoDB provider.

    The connection string and namespace are read with ``get_env_or_fail``
    from ``MONGODB_CONNECTION_STR`` and ``MONGODB_NAMESPACE``.
    """
    mongo_config = MongoDBVectorIOConfig(
        connection_str=get_env_or_fail("MONGODB_CONNECTION_STR"),
        namespace=get_env_or_fail("MONGODB_NAMESPACE"),
    )
    mongo_provider = Provider(
        provider_id="mongodb",
        provider_type="remote::mongodb",
        config=mongo_config.model_dump(),
    )
    return ProviderFixture(providers=[mongo_provider])
|
||||
|
||||
# Names of the available vector_io provider fixtures; each name matches a
# `vector_io_<name>` fixture defined in this module.
VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma", "qdrant", "sqlite_vec", "mongodb"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
async def vector_io_stack(embedding_model, request):
    """Build a test stack for the requested inference / vector_io provider pair.

    ``request.param`` maps ``"inference"`` and ``"vector_io"`` to fixture
    name suffixes. The matching ``<api>_<name>`` fixtures supply the providers
    and any provider data.

    Returns:
        A tuple ``(vector_io_impl, vector_dbs_impl)`` taken from the stack.

    Raises:
        ValueError: If ``EMBEDDING_DIMENSION`` is set but is not an integer.
    """
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    for api_name in ("inference", "vector_io"):
        fixture = request.getfixturevalue(f"{api_name}_{fixture_dict[api_name]}")
        providers[api_name] = fixture.providers
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    # Environment values are strings; convert once here so the registered model
    # carries a numeric dimension and a malformed value fails immediately.
    embedding_dimension = int(get_env_or_fail("EMBEDDING_DIMENSION"))

    test_stack = await construct_stack_for_test(
        [Api.vector_io, Api.inference],
        providers,
        provider_data,
        models=[
            ModelInput(
                model_id=embedding_model,
                model_type=ModelType.embedding,
                metadata={
                    "embedding_dimension": embedding_dimension,
                },
            )
        ],
    )

    return test_stack.impls[Api.vector_io], test_stack.impls[Api.vector_dbs]
|
Loading…
Add table
Add a link
Reference in a new issue