diff --git a/README.md b/README.md index 918433d51..e2529cac1 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,7 @@ Here is a list of the various API providers and available distributions that can | NVIDIA NIM | Hosted and Single Node | | ✅ | | | | | Chroma | Single Node | | | ✅ | | | | PG Vector | Single Node | | | ✅ | | | +| MongoDB Atlas | Hosted | | | ✅ | | | | PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | | vLLM | Hosted and Single Node | | ✅ | | | | | OpenAI | Hosted | | ✅ | | | | diff --git a/docs/source/providers/vector_io/mongodb.md b/docs/source/providers/vector_io/mongodb.md new file mode 100644 index 000000000..67e4adec0 --- /dev/null +++ b/docs/source/providers/vector_io/mongodb.md @@ -0,0 +1,35 @@ +--- +orphan: true +--- +# MongoDB Atlas + +[MongoDB Atlas](https://www.mongodb.com/atlas) is a cloud database service that can be used as a vector store provider for Llama Stack. It supports vector search capabilities through its Atlas Vector Search feature, allowing you to store and query vectors within your MongoDB database. + +## Features +MongoDB Atlas Vector Search supports: +- Storing embeddings and their metadata +- Vector search with multiple similarity functions (cosine similarity, euclidean distance, dot product) +- Hybrid search (combining vector and keyword search) +- Metadata filtering +- Scalable vector indexing +- Managed cloud infrastructure + +## Usage + +To use MongoDB Atlas in your Llama Stack project, follow these steps: + +1. Create a MongoDB Atlas account and cluster. +2. Configure your Atlas cluster to enable Vector Search. +3. Configure your Llama Stack project to use MongoDB Atlas. +4. Start storing and querying vectors. + +## Installation + +You can install the MongoDB Python driver, together with the `certifi` CA bundle the provider uses for TLS, using pip: + +```bash +pip install pymongo certifi +``` + +## Documentation +See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/) for more details about vector search capabilities in MongoDB Atlas. 
diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 93031763d..ee2d4c706 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -110,6 +110,16 @@ def available_providers() -> List[ProviderSpec]: ), api_dependencies=[Api.inference], ), + remote_provider_spec( + Api.vector_io, + AdapterSpec( + adapter_type="mongodb", + pip_packages=["pymongo"], + module="llama_stack.providers.remote.vector_io.mongodb", + config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig", + ), + api_dependencies=[Api.inference], + ), remote_provider_spec( Api.vector_io, AdapterSpec( diff --git a/llama_stack/providers/remote/vector_io/mongodb/__init__.py b/llama_stack/providers/remote/vector_io/mongodb/__init__.py index e69de29bb..fd46551b1 100644 --- a/llama_stack/providers/remote/vector_io/mongodb/__init__.py +++ b/llama_stack/providers/remote/vector_io/mongodb/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + +from .config import MongoDBVectorIOConfig + + +async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from .mongodb import MongoDBVectorIOAdapter + + impl = MongoDBVectorIOAdapter(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/vector_io/mongodb/config.py b/llama_stack/providers/remote/vector_io/mongodb/config.py index 620594566..0ebaad9f6 100644 --- a/llama_stack/providers/remote/vector_io/mongodb/config.py +++ b/llama_stack/providers/remote/vector_io/mongodb/config.py @@ -4,26 +4,23 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List from pydantic import BaseModel, Field class MongoDBVectorIOConfig(BaseModel): - conncetion_str: str - namespace: str = Field(None, description="Namespace of the MongoDB collection") - index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB collection") - filter_fields: Optional[str] = Field(None, description="Fields to filter the MongoDB collection") - embedding_field: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB collection") - text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB collection") + connection_str: str = Field(..., description="Connection string for the MongoDB Atlas collection")  # required, as before the rename: a None default would reach MongoClient unvalidated + namespace: str = Field(None, description="Namespace i.e. 
db_name.collection_name of the MongoDB Atlas collection") + index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB Atlas collection") + filter_fields: Optional[List[str]] = Field([], description="Fields to filter along side vector search in MongoDB Atlas collection") + embeddings_key: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB Atlas collection") + text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB Atlas collection") + @classmethod - def sample_config(cls) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: return { "connection_str": "{env.MONGODB_CONNECTION_STR}", "namespace": "{env.MONGODB_NAMESPACE}", - "index_name": "{env.MONGODB_INDEX_NAME}", - "filter_fields": "{env.MONGODB_FILTER_FIELDS}", - "embedding_field": "{env.MONGODB_EMBEDDING_FIELD}", - "text_field": "{env.MONGODB_TEXT_FIELD}", - } + } \ No newline at end of file diff --git a/llama_stack/providers/remote/vector_io/mongodb/mongodb_atlas.py b/llama_stack/providers/remote/vector_io/mongodb/mongodb.py similarity index 51% rename from llama_stack/providers/remote/vector_io/mongodb/mongodb_atlas.py rename to llama_stack/providers/remote/vector_io/mongodb/mongodb.py index bedb90754..cc172fc70 100644 --- a/llama_stack/providers/remote/vector_io/mongodb/mongodb_atlas.py +++ b/llama_stack/providers/remote/vector_io/mongodb/mongodb.py @@ -1,5 +1,3 @@ -import pymongo - # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
# @@ -11,8 +9,8 @@ import logging from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse -from pymongo import MongoClient, -from pymongo.operations import UpdateOne, InsertOne, DeleteOne, DeleteMany, SearchIndexModel +from pymongo import MongoClient +from pymongo.operations import InsertOne, SearchIndexModel, UpdateOne import certifi from numpy.typing import NDArray @@ -25,19 +23,24 @@ from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) + +from .config import MongoDBVectorIOConfig -from .config import MongoDBAtlasVectorIOConfig from time import sleep log = logging.getLogger(__name__) CHUNK_ID_KEY = "_chunk_id" + class MongoDBAtlasIndex(EmbeddingIndex): - def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, index_name: str): + + def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, embedding_dimension: int, index_name: str, filter_fields: List[str]): self.client = client self.namespace = namespace self.embeddings_key = embeddings_key self.index_name = index_name + self.filter_fields = filter_fields + self.embedding_dimension = embedding_dimension def _get_index_config(self, collection, index_name): idxs = list(collection.list_search_indexes()) @@ -45,14 +48,39 @@ class MongoDBAtlasIndex(EmbeddingIndex): if ele["name"] == index_name: return ele + def _get_search_index_model(self): + index_fields = [ + { + "path": self.embeddings_key, + "type": "vector", + "numDimensions": self.embedding_dimension, + "similarity": "cosine" + } + ] + + if self.filter_fields:  # truthiness also skips a None filter list instead of raising in len() + for filter_field in self.filter_fields: + index_fields.append( + { + "path": filter_field, + "type": "filter" + } + ) + + return SearchIndexModel( + name=self.index_name, + type="vectorSearch", + definition={ + "fields": index_fields + } + ) + def _check_n_create_index(self): client = self.client - db,collection = self.namespace.split(".") + db, collection = 
self.namespace.split(".") collection = client[db][collection] index_name = self.index_name - print(">>>>>>>>Index name: ", index_name, "<<<<<<<<<<") idx = self._get_index_config(collection, index_name) - print(idx) if not idx: log.info("Creating search index ...") search_index_model = self._get_search_index_model() @@ -60,10 +88,10 @@ class MongoDBAtlasIndex(EmbeddingIndex): while True: idx = self._get_index_config(collection, index_name) if idx and idx["queryable"]: - print("Search index created successfully.") + log.info("Search index created successfully.") break else: - print("Waiting for search index to be created ...") + log.info("Waiting for search index to be created ...") sleep(5) else: log.info("Search index already exists.") @@ -77,46 +105,53 @@ class MongoDBAtlasIndex(EmbeddingIndex): operations = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}" + operations.append( - InsertOne( + UpdateOne( + {CHUNK_ID_KEY: chunk_id}, { - CHUNK_ID_KEY: chunk_id, - "chunk_content": chunk.model_dump_json(), - self.embeddings_key: embedding.tolist(), - } + "$set": { + CHUNK_ID_KEY: chunk_id, + "chunk_content": json.loads(chunk.model_dump_json()), + self.embeddings_key: embedding.tolist(), + } + }, + upsert=True, ) ) # Perform the bulk operations - db,collection_name = self.namespace.split(".") + db, collection_name = self.namespace.split(".") collection = self.client[db][collection_name] collection.bulk_write(operations) - - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: - + print(f"Added {len(chunks)} chunks to the collection") # Create a search index model + print("Creating search index ...") self._check_n_create_index() + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: # Perform a query - db,collection_name = self.namespace.split(".") + db, collection_name = self.namespace.split(".") collection = 
self.client[db][collection_name] # Create vector search query - vs_query = {"$vectorSearch": - { - "index": "vector_index", - "path": self.embeddings_key, - "queryVector": embedding.tolist(), - "numCandidates": k, - "limit": k, - } - } + vs_query = {"$vectorSearch": + { + "index": self.index_name, + "path": self.embeddings_key, + "queryVector": embedding.tolist(), + "numCandidates": k, + "limit": k, + } + } # Add a field to store the score score_add_field_query = { "$addFields": { "score": {"$meta": "vectorSearchScore"} } } + if score_threshold is None: + score_threshold = 0.01 # Filter the results based on the score threshold filter_query = { "$match": { @@ -141,60 +176,90 @@ class MongoDBAtlasIndex(EmbeddingIndex): chunks = [] scores = [] for result in results: + content = result["chunk_content"] chunk = Chunk( - metadata={"document_id": result[CHUNK_ID_KEY]}, - content=json.loads(result["chunk_content"]), + metadata=content["metadata"], + content=content["content"], ) chunks.append(chunk) scores.append(result["score"]) return QueryChunksResponse(chunks=chunks, scores=scores) - -class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): - def __init__(self, config: MongoDBAtlasVectorIOConfig, inference_api: Api.inference): + async def delete(self): + db, collection_name = self.namespace.split(".") + self.client[db][collection_name].drop()  # drop only this namespace's collection, not the whole database + + +class MongoDBVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): + def __init__(self, config: MongoDBVectorIOConfig, inference_api: Api.inference): self.config = config self.inference_api = inference_api - + self.cache = {} async def initialize(self) -> None: self.client = MongoClient( - self.config.uri, + self.config.connection_str, tlsCAFile=certifi.where(), ) - self.cache = {} - pass async def shutdown(self) -> None: - self.client.close() - pass + if getattr(self, "client", None) is not None:  # client exists only after initialize(); close it when present + self.client.close() - async def register_vector_db( self, vector_db: VectorDB) -> None: - index = VectorDBWithIndex( + async def register_vector_db(self, vector_db: 
VectorDB) -> None: + index = MongoDBAtlasIndex( + client=self.client, + namespace=self.config.namespace, + embeddings_key=self.config.embeddings_key, + embedding_dimension=vector_db.embedding_dimension, + index_name=self.config.index_name, + filter_fields=self.config.filter_fields, + ) + self.cache[vector_db.identifier] = VectorDBWithIndex( vector_db=vector_db, + index=index, + inference_api=self.inference_api, + ) + + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex: + if vector_db_id in self.cache: + return self.cache[vector_db_id] + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) + self.cache[vector_db_id] = VectorDBWithIndex( + vector_db=vector_db,  # the VectorDB object fetched above, not its id string index=MongoDBAtlasIndex( client=self.client, namespace=self.config.namespace, embeddings_key=self.config.embeddings_key, + embedding_dimension=vector_db.embedding_dimension, index_name=self.config.index_name, + filter_fields=self.config.filter_fields, ), + inference_api=self.inference_api, ) - self.cache[vector_db] = index - pass + return self.cache[vector_db_id] - async def insert_chunks(self, vector_db_id, chunks, ttl_seconds = None): - index = self.cache[vector_db_id].index + async def unregister_vector_db(self, vector_db_id: str) -> None: + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + + async def insert_chunks(self, + vector_db_id: str, + chunks: List[Chunk], + ttl_seconds: Optional[int] = None, + ) -> None: + index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: raise ValueError(f"Vector DB {vector_db_id} not found") await index.insert_chunks(chunks) - - async def query_chunks(self, vector_db_id, query, params = None): - index = self.cache[vector_db_id].index + async def query_chunks(self, + vector_db_id: str, + query: InterleavedContent, + params: Optional[Dict[str, Any]] = None, + ) -> QueryChunksResponse: + index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: raise 
ValueError(f"Vector DB {vector_db_id} not found") return await index.query_chunks(query, params) - - - -