mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
feat: adding mongodb vector_io module
updated mongodb sample run config
This commit is contained in:
parent
bfece15fb4
commit
ee981a0c02
6 changed files with 194 additions and 67 deletions
|
@ -50,6 +50,7 @@ Here is a list of the various API providers and available distributions that can
|
||||||
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
|
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
|
||||||
| Chroma | Single Node | | | ✅ | | |
|
| Chroma | Single Node | | | ✅ | | |
|
||||||
| PG Vector | Single Node | | | ✅ | | |
|
| PG Vector | Single Node | | | ✅ | | |
|
||||||
|
| MongoDB Atlas | Hosted | | | ✅ | | |
|
||||||
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
|
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
|
||||||
| vLLM | Hosted and Single Node | | ✅ | | | |
|
| vLLM | Hosted and Single Node | | ✅ | | | |
|
||||||
| OpenAI | Hosted | | ✅ | | | |
|
| OpenAI | Hosted | | ✅ | | | |
|
||||||
|
|
35
docs/source/providers/vector_io/mongodb.md
Normal file
35
docs/source/providers/vector_io/mongodb.md
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
---
|
||||||
|
orphan: true
|
||||||
|
---
|
||||||
|
# MongoDB Atlas
|
||||||
|
|
||||||
|
[MongoDB Atlas](https://www.mongodb.com/atlas) is a cloud database service that can be used as a vector store provider for Llama Stack. It supports vector search capabilities through its Atlas Vector Search feature, allowing you to store and query vectors within your MongoDB database.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
MongoDB Atlas Vector Search supports:
|
||||||
|
- Store embeddings and their metadata
|
||||||
|
- Vector search with multiple algorithms (cosine similarity, euclidean distance, dot product)
|
||||||
|
- Hybrid search (combining vector and keyword search)
|
||||||
|
- Metadata filtering
|
||||||
|
- Scalable vector indexing
|
||||||
|
- Managed cloud infrastructure
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use MongoDB Atlas in your Llama Stack project, follow these steps:
|
||||||
|
|
||||||
|
1. Create a MongoDB Atlas account and cluster.
|
||||||
|
2. Configure your Atlas cluster to enable Vector Search.
|
||||||
|
3. Configure your Llama Stack project to use MongoDB Atlas.
|
||||||
|
4. Start storing and querying vectors.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
You can install the MongoDB Python driver using pip:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install pymongo
|
||||||
|
```
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/) for more details about vector search capabilities in MongoDB Atlas.
|
|
@ -110,6 +110,16 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
),
|
),
|
||||||
api_dependencies=[Api.inference],
|
api_dependencies=[Api.inference],
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
Api.vector_io,
|
||||||
|
AdapterSpec(
|
||||||
|
adapter_type="mongodb",
|
||||||
|
pip_packages=["pymongo"],
|
||||||
|
module="llama_stack.providers.remote.vector_io.mongodb",
|
||||||
|
config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig",
|
||||||
|
),
|
||||||
|
api_dependencies=[Api.inference],
|
||||||
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
Api.vector_io,
|
Api.vector_io,
|
||||||
AdapterSpec(
|
AdapterSpec(
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||||
|
|
||||||
|
from .config import MongoDBVectorIOConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||||
|
from .mongodb import MongoDBVectorIOAdapter
|
||||||
|
|
||||||
|
impl = MongoDBVectorIOAdapter(config, deps[Api.inference])
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
|
@ -4,26 +4,23 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, List
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class MongoDBVectorIOConfig(BaseModel):
|
class MongoDBVectorIOConfig(BaseModel):
|
||||||
conncetion_str: str
|
connection_str: str = Field(None, description="Connection string for the MongoDB Atlas collection")
|
||||||
namespace: str = Field(None, description="Namespace of the MongoDB collection")
|
namespace: str = Field(None, description="Namespace i.e. db_name.collection_name of the MongoDB Atlas collection")
|
||||||
index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB collection")
|
index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB Atlas collection")
|
||||||
filter_fields: Optional[str] = Field(None, description="Fields to filter the MongoDB collection")
|
filter_fields: Optional[List[str]] = Field([], description="Fields to filter along side vector search in MongoDB Atlas collection")
|
||||||
embedding_field: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB collection")
|
embeddings_key: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB Atlas collection")
|
||||||
text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB collection")
|
text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB Atlas collection")
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_config(cls) -> Dict[str, Any]:
|
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"connection_str": "{env.MONGODB_CONNECTION_STR}",
|
"connection_str": "{env.MONGODB_CONNECTION_STR}",
|
||||||
"namespace": "{env.MONGODB_NAMESPACE}",
|
"namespace": "{env.MONGODB_NAMESPACE}",
|
||||||
"index_name": "{env.MONGODB_INDEX_NAME}",
|
}
|
||||||
"filter_fields": "{env.MONGODB_FILTER_FIELDS}",
|
|
||||||
"embedding_field": "{env.MONGODB_EMBEDDING_FIELD}",
|
|
||||||
"text_field": "{env.MONGODB_TEXT_FIELD}",
|
|
||||||
}
|
|
|
@ -1,5 +1,3 @@
|
||||||
import pymongo
|
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
|
@ -11,8 +9,8 @@ import logging
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from pymongo import MongoClient,
|
from pymongo import MongoClient
|
||||||
from pymongo.operations import UpdateOne, InsertOne, DeleteOne, DeleteMany, SearchIndexModel
|
from pymongo.operations import InsertOne, SearchIndexModel, UpdateOne
|
||||||
import certifi
|
import certifi
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
@ -25,19 +23,24 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
VectorDBWithIndex,
|
VectorDBWithIndex,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .config import MongoDBVectorIOConfig
|
||||||
|
|
||||||
from .config import MongoDBAtlasVectorIOConfig
|
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
CHUNK_ID_KEY = "_chunk_id"
|
CHUNK_ID_KEY = "_chunk_id"
|
||||||
|
|
||||||
|
|
||||||
class MongoDBAtlasIndex(EmbeddingIndex):
|
class MongoDBAtlasIndex(EmbeddingIndex):
|
||||||
def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, index_name: str):
|
|
||||||
|
def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, embedding_dimension: str, index_name: str, filter_fields: List[str]):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.namespace = namespace
|
self.namespace = namespace
|
||||||
self.embeddings_key = embeddings_key
|
self.embeddings_key = embeddings_key
|
||||||
self.index_name = index_name
|
self.index_name = index_name
|
||||||
|
self.filter_fields = filter_fields
|
||||||
|
self.embedding_dimension = embedding_dimension
|
||||||
|
|
||||||
def _get_index_config(self, collection, index_name):
|
def _get_index_config(self, collection, index_name):
|
||||||
idxs = list(collection.list_search_indexes())
|
idxs = list(collection.list_search_indexes())
|
||||||
|
@ -45,14 +48,39 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
||||||
if ele["name"] == index_name:
|
if ele["name"] == index_name:
|
||||||
return ele
|
return ele
|
||||||
|
|
||||||
|
def _get_search_index_model(self):
|
||||||
|
index_fields = [
|
||||||
|
{
|
||||||
|
"path": self.embeddings_key,
|
||||||
|
"type": "vector",
|
||||||
|
"numDimensions": self.embedding_dimension,
|
||||||
|
"similarity": "cosine"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(self.filter_fields) > 0:
|
||||||
|
for filter_field in self.filter_fields:
|
||||||
|
index_fields.append(
|
||||||
|
{
|
||||||
|
"path": filter_field,
|
||||||
|
"type": "filter"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return SearchIndexModel(
|
||||||
|
name=self.index_name,
|
||||||
|
type="vectorSearch",
|
||||||
|
definition={
|
||||||
|
"fields": index_fields
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def _check_n_create_index(self):
|
def _check_n_create_index(self):
|
||||||
client = self.client
|
client = self.client
|
||||||
db,collection = self.namespace.split(".")
|
db, collection = self.namespace.split(".")
|
||||||
collection = client[db][collection]
|
collection = client[db][collection]
|
||||||
index_name = self.index_name
|
index_name = self.index_name
|
||||||
print(">>>>>>>>Index name: ", index_name, "<<<<<<<<<<")
|
|
||||||
idx = self._get_index_config(collection, index_name)
|
idx = self._get_index_config(collection, index_name)
|
||||||
print(idx)
|
|
||||||
if not idx:
|
if not idx:
|
||||||
log.info("Creating search index ...")
|
log.info("Creating search index ...")
|
||||||
search_index_model = self._get_search_index_model()
|
search_index_model = self._get_search_index_model()
|
||||||
|
@ -60,10 +88,10 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
||||||
while True:
|
while True:
|
||||||
idx = self._get_index_config(collection, index_name)
|
idx = self._get_index_config(collection, index_name)
|
||||||
if idx and idx["queryable"]:
|
if idx and idx["queryable"]:
|
||||||
print("Search index created successfully.")
|
log.info("Search index created successfully.")
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
print("Waiting for search index to be created ...")
|
log.info("Waiting for search index to be created ...")
|
||||||
sleep(5)
|
sleep(5)
|
||||||
else:
|
else:
|
||||||
log.info("Search index already exists.")
|
log.info("Search index already exists.")
|
||||||
|
@ -77,46 +105,53 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
||||||
operations = []
|
operations = []
|
||||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||||
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
|
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
|
||||||
|
|
||||||
operations.append(
|
operations.append(
|
||||||
InsertOne(
|
UpdateOne(
|
||||||
|
{CHUNK_ID_KEY: chunk_id},
|
||||||
{
|
{
|
||||||
CHUNK_ID_KEY: chunk_id,
|
"$set": {
|
||||||
"chunk_content": chunk.model_dump_json(),
|
CHUNK_ID_KEY: chunk_id,
|
||||||
self.embeddings_key: embedding.tolist(),
|
"chunk_content": json.loads(chunk.model_dump_json()),
|
||||||
}
|
self.embeddings_key: embedding.tolist(),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
upsert=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Perform the bulk operations
|
# Perform the bulk operations
|
||||||
db,collection_name = self.namespace.split(".")
|
db, collection_name = self.namespace.split(".")
|
||||||
collection = self.client[db][collection_name]
|
collection = self.client[db][collection_name]
|
||||||
collection.bulk_write(operations)
|
collection.bulk_write(operations)
|
||||||
|
print(f"Added {len(chunks)} chunks to the collection")
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
|
||||||
|
|
||||||
# Create a search index model
|
# Create a search index model
|
||||||
|
print("Creating search index ...")
|
||||||
self._check_n_create_index()
|
self._check_n_create_index()
|
||||||
|
|
||||||
|
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
# Perform a query
|
# Perform a query
|
||||||
db,collection_name = self.namespace.split(".")
|
db, collection_name = self.namespace.split(".")
|
||||||
collection = self.client[db][collection_name]
|
collection = self.client[db][collection_name]
|
||||||
|
|
||||||
# Create vector search query
|
# Create vector search query
|
||||||
vs_query = {"$vectorSearch":
|
vs_query = {"$vectorSearch":
|
||||||
{
|
{
|
||||||
"index": "vector_index",
|
"index": self.index_name,
|
||||||
"path": self.embeddings_key,
|
"path": self.embeddings_key,
|
||||||
"queryVector": embedding.tolist(),
|
"queryVector": embedding.tolist(),
|
||||||
"numCandidates": k,
|
"numCandidates": k,
|
||||||
"limit": k,
|
"limit": k,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# Add a field to store the score
|
# Add a field to store the score
|
||||||
score_add_field_query = {
|
score_add_field_query = {
|
||||||
"$addFields": {
|
"$addFields": {
|
||||||
"score": {"$meta": "vectorSearchScore"}
|
"score": {"$meta": "vectorSearchScore"}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if score_threshold is None:
|
||||||
|
score_threshold = 0.01
|
||||||
# Filter the results based on the score threshold
|
# Filter the results based on the score threshold
|
||||||
filter_query = {
|
filter_query = {
|
||||||
"$match": {
|
"$match": {
|
||||||
|
@ -141,60 +176,90 @@ class MongoDBAtlasIndex(EmbeddingIndex):
|
||||||
chunks = []
|
chunks = []
|
||||||
scores = []
|
scores = []
|
||||||
for result in results:
|
for result in results:
|
||||||
|
content = result["chunk_content"]
|
||||||
chunk = Chunk(
|
chunk = Chunk(
|
||||||
metadata={"document_id": result[CHUNK_ID_KEY]},
|
metadata=content["metadata"],
|
||||||
content=json.loads(result["chunk_content"]),
|
content=content["content"],
|
||||||
)
|
)
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
scores.append(result["score"])
|
scores.append(result["score"])
|
||||||
|
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
async def delete(self):
|
||||||
def __init__(self, config: MongoDBAtlasVectorIOConfig, inference_api: Api.inference):
|
db, _ = self.namespace.split(".")
|
||||||
|
self.client.drop_database(db)
|
||||||
|
|
||||||
|
|
||||||
|
class MongoDBVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
|
def __init__(self, config: MongoDBVectorIOConfig, inference_api: Api.inference):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
|
self.cache = {}
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
self.client = MongoClient(
|
self.client = MongoClient(
|
||||||
self.config.uri,
|
self.config.connection_str,
|
||||||
tlsCAFile=certifi.where(),
|
tlsCAFile=certifi.where(),
|
||||||
)
|
)
|
||||||
self.cache = {}
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
self.client.close()
|
if not self.client:
|
||||||
pass
|
self.client.close()
|
||||||
|
|
||||||
async def register_vector_db( self, vector_db: VectorDB) -> None:
|
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||||
index = VectorDBWithIndex(
|
index=MongoDBAtlasIndex(
|
||||||
|
client=self.client,
|
||||||
|
namespace=self.config.namespace,
|
||||||
|
embeddings_key=self.config.embeddings_key,
|
||||||
|
embedding_dimension=vector_db.embedding_dimension,
|
||||||
|
index_name=self.config.index_name,
|
||||||
|
filter_fields=self.config.filter_fields,
|
||||||
|
)
|
||||||
|
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||||
vector_db=vector_db,
|
vector_db=vector_db,
|
||||||
|
index=index,
|
||||||
|
inference_api=self.inference_api,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
|
||||||
|
if vector_db_id in self.cache:
|
||||||
|
return self.cache[vector_db_id]
|
||||||
|
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||||
|
self.cache[vector_db_id] = VectorDBWithIndex(
|
||||||
|
vector_db=vector_db_id,
|
||||||
index=MongoDBAtlasIndex(
|
index=MongoDBAtlasIndex(
|
||||||
client=self.client,
|
client=self.client,
|
||||||
namespace=self.config.namespace,
|
namespace=self.config.namespace,
|
||||||
embeddings_key=self.config.embeddings_key,
|
embeddings_key=self.config.embeddings_key,
|
||||||
|
embedding_dimension=vector_db.embedding_dimension,
|
||||||
index_name=self.config.index_name,
|
index_name=self.config.index_name,
|
||||||
|
filter_fields=self.config.filter_fields,
|
||||||
),
|
),
|
||||||
|
inference_api=self.inference_api,
|
||||||
)
|
)
|
||||||
self.cache[vector_db] = index
|
return self.cache[vector_db_id]
|
||||||
pass
|
|
||||||
|
|
||||||
async def insert_chunks(self, vector_db_id, chunks, ttl_seconds = None):
|
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||||
index = self.cache[vector_db_id].index
|
await self.cache[vector_db_id].index.delete()
|
||||||
|
del self.cache[vector_db_id]
|
||||||
|
|
||||||
|
async def insert_chunks(self,
|
||||||
|
vector_db_id: str,
|
||||||
|
chunks: List[Chunk],
|
||||||
|
ttl_seconds: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||||
if not index:
|
if not index:
|
||||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||||
await index.insert_chunks(chunks)
|
await index.insert_chunks(chunks)
|
||||||
|
|
||||||
|
|
||||||
async def query_chunks(self, vector_db_id, query, params = None):
|
async def query_chunks(self,
|
||||||
index = self.cache[vector_db_id].index
|
vector_db_id: str,
|
||||||
|
query: InterleavedContent,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||||
if not index:
|
if not index:
|
||||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||||
return await index.query_chunks(query, params)
|
return await index.query_chunks(query, params)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue