feat: adding mongodb vector_io module

updated mongodb sample run config
Ashwin Gangadhar 2025-02-19 21:48:05 +05:30
parent bfece15fb4
commit ee981a0c02
6 changed files with 194 additions and 67 deletions

View file

@@ -50,6 +50,7 @@ Here is a list of the various API providers and available distributions that can
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
| Chroma | Single Node | | | ✅ | | |
| PG Vector | Single Node | | | ✅ | | |
| MongoDB Atlas | Hosted | | | ✅ | | |
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
| vLLM | Hosted and Single Node | | ✅ | | | |
| OpenAI | Hosted | | ✅ | | | |

View file

@@ -0,0 +1,35 @@
---
orphan: true
---
# MongoDB Atlas

[MongoDB Atlas](https://www.mongodb.com/atlas) is a cloud database service that can be used as a vector store provider for Llama Stack. Its Atlas Vector Search feature lets you store and query embeddings alongside the rest of your data in a MongoDB database.

## Features

MongoDB Atlas Vector Search supports:

- Storing embeddings and their metadata
- Vector search with multiple similarity functions (cosine similarity, Euclidean distance, dot product)
- Hybrid search (combining vector and keyword search)
- Metadata filtering
- Scalable vector indexing
- Managed cloud infrastructure

## Usage

To use MongoDB Atlas in your Llama Stack project, follow these steps:

1. Create a MongoDB Atlas account and cluster.
2. Enable Vector Search on your Atlas cluster.
3. Configure your Llama Stack project to use MongoDB Atlas.
4. Start storing and querying vectors.
## Installation
You can install the MongoDB Python driver using pip:
```bash
pip install pymongo
```
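
Once the driver is installed, a quick connectivity check can catch configuration problems early. This is a minimal sketch, assuming your Atlas connection string is exported as `MONGODB_CONNECTION_STR` (the variable name is illustrative):

```python
# Sanity-check the Atlas connection before wiring it into Llama Stack.
import os

import certifi
from pymongo import MongoClient

# certifi supplies CA certificates for TLS, matching what the provider itself uses.
client = MongoClient(os.environ["MONGODB_CONNECTION_STR"], tlsCAFile=certifi.where())
client.admin.command("ping")  # raises an exception if the cluster is unreachable
print("Connected to MongoDB Atlas")
```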
## Documentation
See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/) for more details about vector search capabilities in MongoDB Atlas.

View file

@@ -110,6 +110,16 @@ def available_providers() -> List[ProviderSpec]:
            ),
            api_dependencies=[Api.inference],
        ),
        remote_provider_spec(
            Api.vector_io,
            AdapterSpec(
                adapter_type="mongodb",
                pip_packages=["pymongo"],
                module="llama_stack.providers.remote.vector_io.mongodb",
                config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig",
            ),
            api_dependencies=[Api.inference],
        ),
        remote_provider_spec(
            Api.vector_io,
            AdapterSpec(

View file

@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import MongoDBVectorIOConfig


async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: Dict[Api, ProviderSpec]):
    from .mongodb import MongoDBVectorIOAdapter

    impl = MongoDBVectorIOAdapter(config, deps[Api.inference])
    await impl.initialize()
    return impl
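
For local experimentation, this entry point can be driven directly. A minimal sketch, assuming a reachable Atlas cluster; `inference_impl` is a stand-in for an initialized inference provider, and the connection values are illustrative:

```python
import asyncio

from llama_stack.providers.datatypes import Api
from llama_stack.providers.remote.vector_io.mongodb import MongoDBVectorIOConfig, get_adapter_impl

config = MongoDBVectorIOConfig(
    connection_str="mongodb+srv://user:pass@cluster0.example.mongodb.net",
    namespace="llama_stack.chunks",  # db_name.collection_name
)
inference_impl = ...  # stand-in: a real deployment passes the inference provider here

# get_adapter_impl builds the adapter and calls initialize(), which opens
# the MongoClient connection.
adapter = asyncio.run(get_adapter_impl(config, {Api.inference: inference_impl}))
```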

View file

@@ -4,26 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


class MongoDBVectorIOConfig(BaseModel):
    connection_str: str = Field(None, description="Connection string for the MongoDB Atlas cluster")
    namespace: str = Field(None, description="Namespace, i.e. db_name.collection_name, of the MongoDB Atlas collection")
    index_name: Optional[str] = Field("default", description="Name of the vector search index in the MongoDB Atlas collection")
    filter_fields: Optional[List[str]] = Field([], description="Fields to filter on alongside vector search in the MongoDB Atlas collection")
    embeddings_key: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB Atlas collection")
    text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB Atlas collection")

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
        return {
            "connection_str": "{env.MONGODB_CONNECTION_STR}",
            "namespace": "{env.MONGODB_NAMESPACE}",
            "index_name": "{env.MONGODB_INDEX_NAME}",
            "filter_fields": "{env.MONGODB_FILTER_FIELDS}",
            "embeddings_key": "{env.MONGODB_EMBEDDING_FIELD}",
            "text_field": "{env.MONGODB_TEXT_FIELD}",
        }
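
The `{env.VAR}` strings above are placeholders that Llama Stack's run-config machinery substitutes from the environment. For quick local testing, the config can also be constructed directly; a minimal sketch with illustrative values:

```python
# Building the config by hand for local testing; all values are illustrative.
from llama_stack.providers.remote.vector_io.mongodb import MongoDBVectorIOConfig

config = MongoDBVectorIOConfig(
    connection_str="mongodb+srv://user:pass@cluster0.example.mongodb.net",
    namespace="llama_stack.documents",  # db_name.collection_name
    index_name="vector_index",
    filter_fields=["document_id"],  # metadata fields indexed for pre-filtering
)
print(config.model_dump())
```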

View file

@@ -1,5 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -11,8 +9,8 @@ import logging
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse

from pymongo import MongoClient
from pymongo.operations import InsertOne, SearchIndexModel, UpdateOne

import certifi
from numpy.typing import NDArray
@@ -25,19 +23,24 @@
    EmbeddingIndex,
    VectorDBWithIndex,
)

from time import sleep

from .config import MongoDBVectorIOConfig

log = logging.getLogger(__name__)

CHUNK_ID_KEY = "_chunk_id"


class MongoDBAtlasIndex(EmbeddingIndex):
    def __init__(
        self,
        client: MongoClient,
        namespace: str,
        embeddings_key: str,
        embedding_dimension: int,
        index_name: str,
        filter_fields: List[str],
    ):
        self.client = client
        self.namespace = namespace
        self.embeddings_key = embeddings_key
        self.index_name = index_name
        self.filter_fields = filter_fields
        self.embedding_dimension = embedding_dimension

    def _get_index_config(self, collection, index_name):
        idxs = list(collection.list_search_indexes())
@@ -45,14 +48,39 @@ class MongoDBAtlasIndex(EmbeddingIndex):
            if ele["name"] == index_name:
                return ele

    def _get_search_index_model(self):
        # One vector field, plus a "filter" entry per configured metadata field
        index_fields = [
            {
                "path": self.embeddings_key,
                "type": "vector",
                "numDimensions": self.embedding_dimension,
                "similarity": "cosine",
            }
        ]
        if len(self.filter_fields) > 0:
            for filter_field in self.filter_fields:
                index_fields.append(
                    {
                        "path": filter_field,
                        "type": "filter",
                    }
                )
        return SearchIndexModel(
            name=self.index_name,
            type="vectorSearch",
            definition={"fields": index_fields},
        )
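
For reference, this is the index definition the helper above would send to Atlas, shown for a 384-dimension embedding model with one filter field (both values are illustrative):

```python
# Equivalent Atlas Vector Search index definition, as it would appear in the
# Atlas UI or Search API; 384 dimensions and "document_id" are illustrative.
index_definition = {
    "fields": [
        {
            "path": "embeddings",
            "type": "vector",
            "numDimensions": 384,
            "similarity": "cosine",
        },
        {"path": "document_id", "type": "filter"},
    ]
}
```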
    def _check_n_create_index(self):
        client = self.client
        db, collection = self.namespace.split(".")
        collection = client[db][collection]
        index_name = self.index_name
        idx = self._get_index_config(collection, index_name)
        if not idx:
            log.info("Creating search index ...")
            search_index_model = self._get_search_index_model()
@@ -60,10 +88,10 @@ class MongoDBAtlasIndex(EmbeddingIndex):
            while True:
                idx = self._get_index_config(collection, index_name)
                if idx and idx["queryable"]:
                    log.info("Search index created successfully.")
                    break
                else:
                    log.info("Waiting for search index to be created ...")
                    sleep(5)
        else:
            log.info("Search index already exists.")
@@ -77,46 +105,53 @@ class MongoDBAtlasIndex(EmbeddingIndex):
        operations = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
            operations.append(
                UpdateOne(
                    {CHUNK_ID_KEY: chunk_id},
                    {
                        "$set": {
                            CHUNK_ID_KEY: chunk_id,
                            "chunk_content": json.loads(chunk.model_dump_json()),
                            self.embeddings_key: embedding.tolist(),
                        }
                    },
                    upsert=True,
                )
            )

        # Perform the bulk operations
        db, collection_name = self.namespace.split(".")
        collection = self.client[db][collection_name]
        collection.bulk_write(operations)
        log.info(f"Added {len(chunks)} chunks to the collection")

        # Create the search index if it does not already exist
        self._check_n_create_index()

    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        # Perform a query
        db, collection_name = self.namespace.split(".")
        collection = self.client[db][collection_name]

        # Create the vector search stage
        vs_query = {
            "$vectorSearch": {
                "index": self.index_name,
                "path": self.embeddings_key,
                "queryVector": embedding.tolist(),
                "numCandidates": k,
                "limit": k,
            }
        }

        # Add a field to store the score
        score_add_field_query = {
            "$addFields": {
                "score": {"$meta": "vectorSearchScore"}
            }
        }

        if score_threshold is None:
            score_threshold = 0.01

        # Filter the results based on the score threshold
        filter_query = {
            "$match": {
@@ -141,60 +176,90 @@ class MongoDBAtlasIndex(EmbeddingIndex):
        chunks = []
        scores = []
        for result in results:
            content = result["chunk_content"]
            chunk = Chunk(
                metadata=content["metadata"],
                content=content["content"],
            )
            chunks.append(chunk)
            scores.append(result["score"])

        return QueryChunksResponse(chunks=chunks, scores=scores)
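
Putting the three stages together, the method runs an aggregation pipeline equivalent to the following standalone sketch (the cluster URI, namespace, and query vector are illustrative):

```python
import certifi
from pymongo import MongoClient

client = MongoClient("mongodb+srv://user:pass@cluster0.example.mongodb.net", tlsCAFile=certifi.where())
query_vector = [0.0] * 384  # stand-in for a real embedding vector

pipeline = [
    # Stage 1: approximate nearest-neighbor search over the vector index
    {
        "$vectorSearch": {
            "index": "vector_index",
            "path": "embeddings",
            "queryVector": query_vector,
            "numCandidates": 5,
            "limit": 5,
        }
    },
    # Stage 2: surface the similarity score on each document
    {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
    # Stage 3: drop results below the score threshold
    {"$match": {"score": {"$gte": 0.01}}},
]

for doc in client["llama_stack"]["chunks"].aggregate(pipeline):
    print(doc["score"], doc["_chunk_id"])
```

Note that the provider sets `numCandidates` equal to `k`; Atlas generally recommends a candidate pool larger than the final limit for better recall, so this is a knob worth tuning.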
    async def delete(self):
        db, _ = self.namespace.split(".")
        self.client.drop_database(db)


class MongoDBVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    def __init__(self, config: MongoDBVectorIOConfig, inference_api: Api.inference):
        self.config = config
        self.inference_api = inference_api
        self.cache = {}

    async def initialize(self) -> None:
        self.client = MongoClient(
            self.config.connection_str,
            tlsCAFile=certifi.where(),
        )

    async def shutdown(self) -> None:
        if self.client:
            self.client.close()
    async def register_vector_db(self, vector_db: VectorDB) -> None:
        index = MongoDBAtlasIndex(
            client=self.client,
            namespace=self.config.namespace,
            embeddings_key=self.config.embeddings_key,
            embedding_dimension=vector_db.embedding_dimension,
            index_name=self.config.index_name,
            filter_fields=self.config.filter_fields,
        )
        self.cache[vector_db.identifier] = VectorDBWithIndex(
            vector_db=vector_db,
            index=index,
            inference_api=self.inference_api,
        )

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
        self.cache[vector_db_id] = VectorDBWithIndex(
            vector_db=vector_db,
            index=MongoDBAtlasIndex(
                client=self.client,
                namespace=self.config.namespace,
                embeddings_key=self.config.embeddings_key,
                embedding_dimension=vector_db.embedding_dimension,
                index_name=self.config.index_name,
                filter_fields=self.config.filter_fields,
            ),
            inference_api=self.inference_api,
        )
        return self.cache[vector_db_id]
    async def unregister_vector_db(self, vector_db_id: str) -> None:
        await self.cache[vector_db_id].index.delete()
        del self.cache[vector_db_id]

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: List[Chunk],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryChunksResponse:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        return await index.query_chunks(query, params)
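
End to end, the adapter is exercised through the `vector_io` API. A rough sketch continuing from the `get_adapter_impl` example earlier; the `VectorDB` fields follow Llama Stack's `vector_dbs` API but are not taken from this commit, and the chunk payload is illustrative:

```python
import asyncio

from llama_stack.apis.vector_dbs import VectorDB

# `adapter` comes from the get_adapter_impl sketch above; `chunks` stands in
# for a list of Chunk objects whose metadata includes a "document_id" key.

async def main() -> None:
    vector_db = VectorDB(
        identifier="my-docs",
        provider_id="mongodb",
        provider_resource_id="my-docs",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
    )
    # Creates the MongoDBAtlasIndex and caches it under vector_db.identifier.
    await adapter.register_vector_db(vector_db)

    # insert_chunks embeds the chunks via the inference API, bulk-upserts
    # them, and creates the Atlas search index if it does not exist yet.
    await adapter.insert_chunks("my-docs", chunks)

    # query_chunks embeds the query and runs the $vectorSearch pipeline.
    response = await adapter.query_chunks("my-docs", "What is vector search?")
    for chunk, score in zip(response.chunks, response.scores):
        print(score, chunk.content)

asyncio.run(main())
```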