add MongoDB vector_io module

update mongodb.py from print to logging

add documentation for the MongoDB vector search module

change MongoDB insert to update (upsert)

fix MongoDB JSON object conversion error
Ashwin Gangadhar 2025-02-19 21:48:05 +05:30
parent d224ae0c8e
commit 80d9d50954
8 changed files with 503 additions and 65 deletions


@@ -48,6 +48,7 @@ Here is a list of the various API providers and available distributions that can
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
| Chroma | Single Node | | | ✅ | | |
| PG Vector | Single Node | | | ✅ | | |
| MongoDB Atlas | Hosted | | | ✅ | | |
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
| vLLM | Hosted and Single Node | | ✅ | | | |


@@ -0,0 +1,35 @@
---
orphan: true
---
# MongoDB Atlas
[MongoDB Atlas](https://www.mongodb.com/atlas) is a cloud database service that can be used as a vector store provider for Llama Stack. It supports vector search capabilities through its Atlas Vector Search feature, allowing you to store and query vectors within your MongoDB database.
## Features
MongoDB Atlas Vector Search supports:
- Storage of embeddings and their metadata
- Vector search with multiple similarity metrics (cosine similarity, Euclidean distance, dot product)
- Hybrid search (combining vector and keyword search)
- Metadata filtering
- Scalable vector indexing
- Managed cloud infrastructure
## Usage
To use MongoDB Atlas in your Llama Stack project, follow these steps:
1. Create a MongoDB Atlas account and cluster.
2. Configure your Atlas cluster to enable Vector Search.
3. Configure your Llama Stack project to use MongoDB Atlas.
4. Start storing and querying vectors.
## Installation
You can install the MongoDB Python driver using pip:
```bash
pip install pymongo
```
## Documentation
See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/) for more details about vector search capabilities in MongoDB Atlas.
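Before wiring the provider into a Llama Stack run, it can help to verify connectivity the same way the adapter does. A minimal sketch, assuming the `MONGODB_CONNECTION_STR` and `MONGODB_NAMESPACE` environment variables from the provider's sample config are set; the `certifi` CA bundle mirrors the adapter's TLS setup:
```python
import os

import certifi
from pymongo import MongoClient

# Connect as the adapter does: connection string plus certifi's CA bundle.
client = MongoClient(os.environ["MONGODB_CONNECTION_STR"], tlsCAFile=certifi.where())
client.admin.command("ping")  # raises on connection or auth failure

# The namespace is "<database>.<collection>", split exactly as the adapter splits it.
db_name, collection_name = os.environ["MONGODB_NAMESPACE"].split(".")
print(f"Connected; will use {db_name}.{collection_name}")
```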


@@ -112,6 +112,16 @@ def available_providers() -> List[ProviderSpec]:
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="mongodb",
pip_packages=["pymongo"],
module="llama_stack.providers.remote.vector_io.mongodb",
config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig",
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(


@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import MongoDBVectorIOConfig
async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .mongodb import MongoDBVectorIOAdapter
impl = MongoDBVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl
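For orientation, this is roughly how the stack consumes the registry entry above: import the `module` string, then await `get_adapter_impl` with the config and the resolved dependencies. A hedged sketch; `inference_impl` stands in for whatever inference provider the stack has already built:
```python
import importlib

from llama_stack.providers.datatypes import Api

async def build_mongodb_adapter(config, inference_impl):
    # Mirrors the AdapterSpec fields: module string + get_adapter_impl(config, deps).
    module = importlib.import_module("llama_stack.providers.remote.vector_io.mongodb")
    return await module.get_adapter_impl(config, {Api.inference: inference_impl})
```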


@@ -4,26 +4,22 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List
from pydantic import BaseModel, Field
class MongoDBVectorIOConfig(BaseModel):
conncetion_str: str
namespace: str = Field(None, description="Namespace of the MongoDB collection")
index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB collection")
filter_fields: Optional[str] = Field(None, description="Fields to filter the MongoDB collection")
embedding_field: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB collection")
text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB collection")
connection_str: str = Field(None, description="Connection string for the MongoDB Atlas collection")
namespace: str = Field(None, description="Namespace, i.e. db_name.collection_name, of the MongoDB Atlas collection")
index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB Atlas collection")
filter_fields: Optional[List[str]] = Field([], description="Fields to filter alongside vector search in the MongoDB Atlas collection")
embeddings_key: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB Atlas collection")
text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB Atlas collection")
@classmethod
def sample_config(cls) -> Dict[str, Any]:
return {
"connection_str": "{env.MONGODB_CONNECTION_STR}",
"namespace": "{env.MONGODB_NAMESPACE}",
"index_name": "{env.MONGODB_INDEX_NAME}",
"filter_fields": "{env.MONGODB_FILTER_FIELDS}",
"embedding_field": "{env.MONGODB_EMBEDDING_FIELD}",
"text_field": "{env.MONGODB_TEXT_FIELD}",
}
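Constructing the updated config directly looks like the sketch below; all values are illustrative. Note that `namespace` is `<db>.<collection>` and `filter_fields` is now a list of field names:
```python
from llama_stack.providers.remote.vector_io.mongodb import MongoDBVectorIOConfig

config = MongoDBVectorIOConfig(
    connection_str="mongodb+srv://user:pass@cluster0.example.mongodb.net",  # placeholder
    namespace="llama_stack.chunks",  # "<db>.<collection>"; placeholder names
    index_name="default",
    filter_fields=["document_id"],   # each becomes an Atlas "filter" field
)
```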


@@ -1,5 +1,3 @@
import pymongo
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -11,8 +9,8 @@ import logging
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
from pymongo import MongoClient,
from pymongo.operations import UpdateOne, InsertOne, DeleteOne, DeleteMany, SearchIndexModel
from pymongo import MongoClient
from pymongo.operations import InsertOne, SearchIndexModel, UpdateOne
import certifi
from numpy.typing import NDArray
@@ -25,19 +23,24 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import MongoDBVectorIOConfig
from .config import MongoDBAtlasVectorIOConfig
from time import sleep
log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"
class MongoDBAtlasIndex(EmbeddingIndex):
def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, index_name: str):
def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, embedding_dimension: int, index_name: str, filter_fields: List[str]):
self.client = client
self.namespace = namespace
self.embeddings_key = embeddings_key
self.index_name = index_name
self.filter_fields = filter_fields
self.embedding_dimension = embedding_dimension
def _get_index_config(self, collection, index_name):
idxs = list(collection.list_search_indexes())
@@ -45,14 +48,39 @@ class MongoDBAtlasIndex(EmbeddingIndex):
if ele["name"] == index_name:
return ele
def _get_search_index_model(self):
index_fields = [
{
"path": self.embeddings_key,
"type": "vector",
"numDimensions": self.embedding_dimension,
"similarity": "cosine"
}
]
if len(self.filter_fields) > 0:
for filter_field in self.filter_fields:
index_fields.append(
{
"path": filter_field,
"type": "filter"
}
)
return SearchIndexModel(
name=self.index_name,
type="vectorSearch",
definition={
"fields": index_fields
}
)
def _check_n_create_index(self):
client = self.client
db,collection = self.namespace.split(".")
db, collection = self.namespace.split(".")
collection = client[db][collection]
index_name = self.index_name
print(">>>>>>>>Index name: ", index_name, "<<<<<<<<<<")
idx = self._get_index_config(collection, index_name)
print(idx)
if not idx:
log.info("Creating search index ...")
search_index_model = self._get_search_index_model()
@@ -60,10 +88,10 @@ class MongoDBAtlasIndex(EmbeddingIndex):
while True:
idx = self._get_index_config(collection, index_name)
if idx and idx["queryable"]:
print("Search index created successfully.")
log.info("Search index created successfully.")
break
else:
print("Waiting for search index to be created ...")
log.info("Waiting for search index to be created ...")
sleep(5)
else:
log.info("Search index already exists.")
@@ -77,46 +105,53 @@ class MongoDBAtlasIndex(EmbeddingIndex):
operations = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
operations.append(
InsertOne(
UpdateOne(
{CHUNK_ID_KEY: chunk_id},
{
CHUNK_ID_KEY: chunk_id,
"chunk_content": chunk.model_dump_json(),
self.embeddings_key: embedding.tolist(),
}
"$set": {
CHUNK_ID_KEY: chunk_id,
"chunk_content": json.loads(chunk.model_dump_json()),
self.embeddings_key: embedding.tolist(),
}
},
upsert=True,
)
)
# Perform the bulk operations
db,collection_name = self.namespace.split(".")
db, collection_name = self.namespace.split(".")
collection = self.client[db][collection_name]
collection.bulk_write(operations)
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
print(f"Added {len(chunks)} chunks to the collection")
# Create a search index model
print("Creating search index ...")
self._check_n_create_index()
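# Illustration (hedged): after the upsert above, each stored document has roughly the shape
#   {"_chunk_id": "<document_id>:chunk-<i>",
#    "chunk_content": {...},            # the chunk as a JSON object (hence json.loads)
#    "embeddings": [0.12, -0.03, ...]}  # list[float] from embedding.tolist()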
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
# Perform a query
db,collection_name = self.namespace.split(".")
db, collection_name = self.namespace.split(".")
collection = self.client[db][collection_name]
# Create vector search query
vs_query = {"$vectorSearch":
{
"index": "vector_index",
"path": self.embeddings_key,
"queryVector": embedding.tolist(),
"numCandidates": k,
"limit": k,
}
}
vs_query = {"$vectorSearch":
{
"index": self.index_name,
"path": self.embeddings_key,
"queryVector": embedding.tolist(),
"numCandidates": k,
"limit": k,
}
}
# Add a field to store the score
score_add_field_query = {
"$addFields": {
"score": {"$meta": "vectorSearchScore"}
}
}
if score_threshold is None:
score_threshold = 0.01
# Filter the results based on the score threshold
filter_query = {
"$match": {
@@ -141,60 +176,90 @@ class MongoDBAtlasIndex(EmbeddingIndex):
chunks = []
scores = []
for result in results:
content = result["chunk_content"]
chunk = Chunk(
metadata={"document_id": result[CHUNK_ID_KEY]},
content=json.loads(result["chunk_content"]),
metadata=content["metadata"],
content=content["content"],
)
chunks.append(chunk)
scores.append(result["score"])
return QueryChunksResponse(chunks=chunks, scores=scores)
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: MongoDBAtlasVectorIOConfig, inference_api: Api.inference):
async def delete(self):
db, _ = self.namespace.split(".")
self.client.drop_database(db)
class MongoDBVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: MongoDBVectorIOConfig, inference_api: Api.inference):
self.config = config
self.inference_api = inference_api
self.cache = {}
async def initialize(self) -> None:
self.client = MongoClient(
self.config.uri,
self.config.connection_str,
tlsCAFile=certifi.where(),
)
self.cache = {}
pass
async def shutdown(self) -> None:
self.client.close()
pass
if self.client:
self.client.close()
async def register_vector_db( self, vector_db: VectorDB) -> None:
index = VectorDBWithIndex(
async def register_vector_db(self, vector_db: VectorDB) -> None:
index=MongoDBAtlasIndex(
client=self.client,
namespace=self.config.namespace,
embeddings_key=self.config.embeddings_key,
embedding_dimension=vector_db.embedding_dimension,
index_name=self.config.index_name,
filter_fields=self.config.filter_fields,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db=vector_db,
index=index,
inference_api=self.inference_api,
)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
self.cache[vector_db_id] = VectorDBWithIndex(
vector_db=vector_db,
index=MongoDBAtlasIndex(
client=self.client,
namespace=self.config.namespace,
embeddings_key=self.config.embeddings_key,
embedding_dimension=vector_db.embedding_dimension,
index_name=self.config.index_name,
filter_fields=self.config.filter_fields,
),
inference_api=self.inference_api,
)
self.cache[vector_db] = index
pass
return self.cache[vector_db_id]
async def insert_chunks(self, vector_db_id, chunks, ttl_seconds = None):
index = self.cache[vector_db_id].index
async def unregister_vector_db(self, vector_db_id: str) -> None:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def insert_chunks(self,
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
await index.insert_chunks(chunks)
async def query_chunks(self, vector_db_id, query, params = None):
index = self.cache[vector_db_id].index
async def query_chunks(self,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
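End to end, a caller only touches the adapter's public surface. A hedged sketch; `adapter` and `my_chunks` are placeholders for objects produced by a configured stack:
```python
async def roundtrip(adapter, my_chunks):
    # Unknown IDs raise ValueError("Vector DB ... not found"), per the checks above.
    await adapter.insert_chunks("my-vector-db", my_chunks)
    response = await adapter.query_chunks("my-vector-db", "what does the document say?")
    for chunk, score in zip(response.chunks, response.scores):
        print(score, chunk.metadata)
```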


@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import (
get_provider_fixture_overrides,
get_provider_fixture_overrides_from_test_config,
get_test_config_for_api,
)
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import VECTOR_IO_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "sentence_transformers",
"vector_io": "faiss",
},
id="sentence_transformers",
marks=pytest.mark.sentence_transformers,
),
pytest.param(
{
"inference": "ollama",
"vector_io": "pgvector",
},
id="pgvector",
marks=pytest.mark.pgvector,
),
pytest.param(
{
"inference": "ollama",
"vector_io": "faiss",
},
id="ollama",
marks=pytest.mark.ollama,
),
pytest.param(
{
"inference": "ollama",
"vector_io": "sqlite_vec",
},
id="sqlite_vec",
marks=pytest.mark.ollama,
),
pytest.param(
{
"inference": "sentence_transformers",
"vector_io": "chroma",
},
id="chroma",
marks=pytest.mark.chroma,
),
pytest.param(
{
"inference": "ollama",
"vector_io": "qdrant",
},
id="qdrant",
marks=pytest.mark.qdrant,
),
pytest.param(
{
"inference": "fireworks",
"vector_io": "weaviate",
},
id="weaviate",
marks=pytest.mark.weaviate,
),
pytest.param(
{
"inference": "bedrock",
"vector_io": "mongodb",
},
id="mongodb",
marks=pytest.mark.mongodb,
),
]
def pytest_configure(config):
for fixture_name in VECTOR_IO_FIXTURES:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_generate_tests(metafunc):
test_config = get_test_config_for_api(metafunc.config, "vector_io")
if "embedding_model" in metafunc.fixturenames:
model = getattr(test_config, "embedding_model", None)
# Fall back to the default if not specified by the config file
model = model or metafunc.config.getoption("--embedding-model")
if model:
params = [pytest.param(model, id="")]
else:
params = [pytest.param("all-minilm:l6-v2", id="")]
metafunc.parametrize("embedding_model", params, indirect=True)
if "vector_io_stack" in metafunc.fixturenames:
available_fixtures = {
"inference": INFERENCE_FIXTURES,
"vector_io": VECTOR_IO_FIXTURES,
}
combinations = (
get_provider_fixture_overrides_from_test_config(metafunc.config, "vector_io", DEFAULT_PROVIDER_COMBINATIONS)
or get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("vector_io_stack", combinations, indirect=True)


@@ -0,0 +1,196 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.qdrant import QdrantVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate import WeaviateVectorIOConfig
from llama_stack.providers.remote.vector_io.mongodb import MongoDBVectorIOConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def embedding_model(request):
if hasattr(request, "param"):
return request.param
return request.config.getoption("--embedding-model", None)
@pytest.fixture(scope="session")
def vector_io_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def vector_io_faiss() -> ProviderFixture:
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
return ProviderFixture(
providers=[
Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig(
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def vector_io_sqlite_vec() -> ProviderFixture:
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
return ProviderFixture(
providers=[
Provider(
provider_id="sqlite_vec",
provider_type="inline::sqlite_vec",
config=SQLiteVectorIOConfig(
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def vector_io_pgvector() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="pgvector",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig(
host=os.getenv("PGVECTOR_HOST", "localhost"),
port=os.getenv("PGVECTOR_PORT", 5432),
db=get_env_or_fail("PGVECTOR_DB"),
user=get_env_or_fail("PGVECTOR_USER"),
password=get_env_or_fail("PGVECTOR_PASSWORD"),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def vector_io_weaviate() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="weaviate",
provider_type="remote::weaviate",
config=WeaviateVectorIOConfig().model_dump(),
)
],
provider_data=dict(
weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
),
)
@pytest.fixture(scope="session")
def vector_io_chroma() -> ProviderFixture:
url = os.getenv("CHROMA_URL")
if url:
config = ChromaVectorIOConfig(url=url)
provider_type = "remote::chromadb"
else:
if not os.getenv("CHROMA_DB_PATH"):
raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
config = InlineChromaVectorIOConfig(db_path=os.getenv("CHROMA_DB_PATH"))
provider_type = "inline::chromadb"
return ProviderFixture(
providers=[
Provider(
provider_id="chroma",
provider_type=provider_type,
config=config.model_dump(),
)
]
)
@pytest.fixture(scope="session")
def vector_io_qdrant() -> ProviderFixture:
url = os.getenv("QDRANT_URL")
if url:
config = QdrantVectorIOConfig(url=url)
provider_type = "remote::qdrant"
else:
raise ValueError("QDRANT_URL must be set")
return ProviderFixture(
providers=[
Provider(
provider_id="qdrant",
provider_type=provider_type,
config=config.model_dump(),
)
]
)
@pytest.fixture(scope="session")
def vector_io_mongodb() -> ProviderFixture:
connection_str = get_env_or_fail("MONGODB_CONNECTION_STR")
namespace = get_env_or_fail("MONGODB_NAMESPACE")
config = MongoDBVectorIOConfig(connection_str=connection_str, namespace=namespace)
provider_type = "remote::mongodb"
return ProviderFixture(
providers=[
Provider(
provider_id="mongodb",
provider_type=provider_type,
config=config.model_dump(),
)
]
)
VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma", "qdrant", "sqlite_vec", "mongodb"]
@pytest_asyncio.fixture(scope="session")
async def vector_io_stack(embedding_model, request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["inference", "vector_io"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[Api.vector_io, Api.inference],
providers,
provider_data,
models=[
ModelInput(
model_id=embedding_model,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
},
)
],
)
return test_stack.impls[Api.vector_io], test_stack.impls[Api.vector_dbs]