mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
feat: adding mongodb vector_io module
updated mongodb sample run config
This commit is contained in:
parent
bfece15fb4
commit
ee981a0c02
6 changed files with 194 additions and 67 deletions
|
@ -50,6 +50,7 @@ Here is a list of the various API providers and available distributions that can
|
|||
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
|
||||
| Chroma | Single Node | | | ✅ | | |
|
||||
| PG Vector | Single Node | | | ✅ | | |
|
||||
| MongoDB Atlas | Hosted | | | ✅ | | |
|
||||
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
|
||||
| vLLM | Hosted and Single Node | | ✅ | | | |
|
||||
| OpenAI | Hosted | | ✅ | | | |
|
||||
|
|
35
docs/source/providers/vector_io/mongodb.md
Normal file
35
docs/source/providers/vector_io/mongodb.md
Normal file
|
@ -0,0 +1,35 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# MongoDB Atlas
|
||||
|
||||
[MongoDB Atlas](https://www.mongodb.com/atlas) is a cloud database service that can be used as a vector store provider for Llama Stack. It supports vector search capabilities through its Atlas Vector Search feature, allowing you to store and query vectors within your MongoDB database.
|
||||
|
||||
## Features
|
||||
MongoDB Atlas Vector Search supports:
|
||||
- Store embeddings and their metadata
|
||||
- Vector search with multiple algorithms (cosine similarity, euclidean distance, dot product)
|
||||
- Hybrid search (combining vector and keyword search)
|
||||
- Metadata filtering
|
||||
- Scalable vector indexing
|
||||
- Managed cloud infrastructure
|
||||
|
||||
## Usage
|
||||
|
||||
To use MongoDB Atlas in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Create a MongoDB Atlas account and cluster.
|
||||
2. Configure your Atlas cluster to enable Vector Search.
|
||||
3. Configure your Llama Stack project to use MongoDB Atlas.
|
||||
4. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
||||
You can install the MongoDB Python driver using pip:
|
||||
|
||||
```bash
|
||||
pip install pymongo
|
||||
```
|
||||
|
||||
## Documentation
|
||||
See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/) for more details about vector search capabilities in MongoDB Atlas.
|
|
@ -110,6 +110,16 @@ def available_providers() -> List[ProviderSpec]:
|
|||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="mongodb",
|
||||
pip_packages=["pymongo"],
|
||||
module="llama_stack.providers.remote.vector_io.mongodb",
|
||||
config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MongoDBVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from .mongodb import MongoDBVectorIOAdapter
|
||||
|
||||
impl = MongoDBVectorIOAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -4,26 +4,23 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MongoDBVectorIOConfig(BaseModel):
|
||||
conncetion_str: str
|
||||
namespace: str = Field(None, description="Namespace of the MongoDB collection")
|
||||
index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB collection")
|
||||
filter_fields: Optional[str] = Field(None, description="Fields to filter the MongoDB collection")
|
||||
embedding_field: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB collection")
|
||||
text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB collection")
|
||||
connection_str: str = Field(None, description="Connection string for the MongoDB Atlas collection")
|
||||
namespace: str = Field(None, description="Namespace i.e. db_name.collection_name of the MongoDB Atlas collection")
|
||||
index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB Atlas collection")
|
||||
filter_fields: Optional[List[str]] = Field([], description="Fields to filter along side vector search in MongoDB Atlas collection")
|
||||
embeddings_key: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB Atlas collection")
|
||||
text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB Atlas collection")
|
||||
|
||||
|
||||
@classmethod
|
||||
def sample_config(cls) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"connection_str": "{env.MONGODB_CONNECTION_STR}",
|
||||
"namespace": "{env.MONGODB_NAMESPACE}",
|
||||
"index_name": "{env.MONGODB_INDEX_NAME}",
|
||||
"filter_fields": "{env.MONGODB_FILTER_FIELDS}",
|
||||
"embedding_field": "{env.MONGODB_EMBEDDING_FIELD}",
|
||||
"text_field": "{env.MONGODB_TEXT_FIELD}",
|
||||
}
|
||||
}
|
|
@ -1,5 +1,3 @@
|
|||
import pymongo
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
|
@ -11,8 +9,8 @@ import logging
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pymongo import MongoClient,
|
||||
from pymongo.operations import UpdateOne, InsertOne, DeleteOne, DeleteMany, SearchIndexModel
|
||||
from pymongo import MongoClient
|
||||
from pymongo.operations import InsertOne, SearchIndexModel, UpdateOne
|
||||
import certifi
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
@ -25,19 +23,24 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
from .config import MongoDBVectorIOConfig
|
||||
|
||||
from .config import MongoDBAtlasVectorIOConfig
|
||||
from time import sleep
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
CHUNK_ID_KEY = "_chunk_id"
|
||||
|
||||
|
||||
class MongoDBAtlasIndex(EmbeddingIndex):
|
||||
def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, index_name: str):
|
||||
|
||||
def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, embedding_dimension: str, index_name: str, filter_fields: List[str]):
|
||||
self.client = client
|
||||
self.namespace = namespace
|
||||
self.embeddings_key = embeddings_key
|
||||
self.index_name = index_name
|
||||
self.filter_fields = filter_fields
|
||||
self.embedding_dimension = embedding_dimension
|
||||
|
||||
def _get_index_config(self, collection, index_name):
|
||||
idxs = list(collection.list_search_indexes())
|
||||
|
@ -45,14 +48,39 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
|||
if ele["name"] == index_name:
|
||||
return ele
|
||||
|
||||
def _get_search_index_model(self):
|
||||
index_fields = [
|
||||
{
|
||||
"path": self.embeddings_key,
|
||||
"type": "vector",
|
||||
"numDimensions": self.embedding_dimension,
|
||||
"similarity": "cosine"
|
||||
}
|
||||
]
|
||||
|
||||
if len(self.filter_fields) > 0:
|
||||
for filter_field in self.filter_fields:
|
||||
index_fields.append(
|
||||
{
|
||||
"path": filter_field,
|
||||
"type": "filter"
|
||||
}
|
||||
)
|
||||
|
||||
return SearchIndexModel(
|
||||
name=self.index_name,
|
||||
type="vectorSearch",
|
||||
definition={
|
||||
"fields": index_fields
|
||||
}
|
||||
)
|
||||
|
||||
def _check_n_create_index(self):
|
||||
client = self.client
|
||||
db,collection = self.namespace.split(".")
|
||||
db, collection = self.namespace.split(".")
|
||||
collection = client[db][collection]
|
||||
index_name = self.index_name
|
||||
print(">>>>>>>>Index name: ", index_name, "<<<<<<<<<<")
|
||||
idx = self._get_index_config(collection, index_name)
|
||||
print(idx)
|
||||
if not idx:
|
||||
log.info("Creating search index ...")
|
||||
search_index_model = self._get_search_index_model()
|
||||
|
@ -60,10 +88,10 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
|||
while True:
|
||||
idx = self._get_index_config(collection, index_name)
|
||||
if idx and idx["queryable"]:
|
||||
print("Search index created successfully.")
|
||||
log.info("Search index created successfully.")
|
||||
break
|
||||
else:
|
||||
print("Waiting for search index to be created ...")
|
||||
log.info("Waiting for search index to be created ...")
|
||||
sleep(5)
|
||||
else:
|
||||
log.info("Search index already exists.")
|
||||
|
@ -77,46 +105,53 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
|||
operations = []
|
||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
|
||||
|
||||
operations.append(
|
||||
InsertOne(
|
||||
UpdateOne(
|
||||
{CHUNK_ID_KEY: chunk_id},
|
||||
{
|
||||
CHUNK_ID_KEY: chunk_id,
|
||||
"chunk_content": chunk.model_dump_json(),
|
||||
self.embeddings_key: embedding.tolist(),
|
||||
}
|
||||
"$set": {
|
||||
CHUNK_ID_KEY: chunk_id,
|
||||
"chunk_content": json.loads(chunk.model_dump_json()),
|
||||
self.embeddings_key: embedding.tolist(),
|
||||
}
|
||||
},
|
||||
upsert=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Perform the bulk operations
|
||||
db,collection_name = self.namespace.split(".")
|
||||
db, collection_name = self.namespace.split(".")
|
||||
collection = self.client[db][collection_name]
|
||||
collection.bulk_write(operations)
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
|
||||
print(f"Added {len(chunks)} chunks to the collection")
|
||||
# Create a search index model
|
||||
print("Creating search index ...")
|
||||
self._check_n_create_index()
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
# Perform a query
|
||||
db,collection_name = self.namespace.split(".")
|
||||
db, collection_name = self.namespace.split(".")
|
||||
collection = self.client[db][collection_name]
|
||||
|
||||
# Create vector search query
|
||||
vs_query = {"$vectorSearch":
|
||||
{
|
||||
"index": "vector_index",
|
||||
"path": self.embeddings_key,
|
||||
"queryVector": embedding.tolist(),
|
||||
"numCandidates": k,
|
||||
"limit": k,
|
||||
}
|
||||
}
|
||||
vs_query = {"$vectorSearch":
|
||||
{
|
||||
"index": self.index_name,
|
||||
"path": self.embeddings_key,
|
||||
"queryVector": embedding.tolist(),
|
||||
"numCandidates": k,
|
||||
"limit": k,
|
||||
}
|
||||
}
|
||||
# Add a field to store the score
|
||||
score_add_field_query = {
|
||||
"$addFields": {
|
||||
"score": {"$meta": "vectorSearchScore"}
|
||||
}
|
||||
}
|
||||
if score_threshold is None:
|
||||
score_threshold = 0.01
|
||||
# Filter the results based on the score threshold
|
||||
filter_query = {
|
||||
"$match": {
|
||||
|
@ -141,60 +176,90 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
|||
chunks = []
|
||||
scores = []
|
||||
for result in results:
|
||||
content = result["chunk_content"]
|
||||
chunk = Chunk(
|
||||
metadata={"document_id": result[CHUNK_ID_KEY]},
|
||||
content=json.loads(result["chunk_content"]),
|
||||
metadata=content["metadata"],
|
||||
content=content["content"],
|
||||
)
|
||||
chunks.append(chunk)
|
||||
scores.append(result["score"])
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: MongoDBAtlasVectorIOConfig, inference_api: Api.inference):
|
||||
async def delete(self):
|
||||
db, _ = self.namespace.split(".")
|
||||
self.client.drop_database(db)
|
||||
|
||||
|
||||
class MongoDBVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: MongoDBVectorIOConfig, inference_api: Api.inference):
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
|
||||
self.cache = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.client = MongoClient(
|
||||
self.config.uri,
|
||||
self.config.connection_str,
|
||||
tlsCAFile=certifi.where(),
|
||||
)
|
||||
self.cache = {}
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
pass
|
||||
if not self.client:
|
||||
self.client.close()
|
||||
|
||||
async def register_vector_db( self, vector_db: VectorDB) -> None:
|
||||
index = VectorDBWithIndex(
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
index=MongoDBAtlasIndex(
|
||||
client=self.client,
|
||||
namespace=self.config.namespace,
|
||||
embeddings_key=self.config.embeddings_key,
|
||||
embedding_dimension=vector_db.embedding_dimension,
|
||||
index_name=self.config.index_name,
|
||||
filter_fields=self.config.filter_fields,
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
self.cache[vector_db_id] = VectorDBWithIndex(
|
||||
vector_db=vector_db_id,
|
||||
index=MongoDBAtlasIndex(
|
||||
client=self.client,
|
||||
namespace=self.config.namespace,
|
||||
embeddings_key=self.config.embeddings_key,
|
||||
embedding_dimension=vector_db.embedding_dimension,
|
||||
index_name=self.config.index_name,
|
||||
filter_fields=self.config.filter_fields,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db] = index
|
||||
pass
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
async def insert_chunks(self, vector_db_id, chunks, ttl_seconds = None):
|
||||
index = self.cache[vector_db_id].index
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def insert_chunks(self,
|
||||
vector_db_id: str,
|
||||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
|
||||
async def query_chunks(self, vector_db_id, query, params = None):
|
||||
index = self.cache[vector_db_id].index
|
||||
async def query_chunks(self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
|
||||
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue