This commit is contained in:
Young Han 2025-11-12 18:16:28 +00:00 committed by GitHub
commit fdc9ba2687
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 2066 additions and 3 deletions

View file

@ -31,7 +31,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"]
vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant", "remote::mongodb"]
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
fail-fast: false # we want to run all tests regardless of failure
@ -101,6 +101,16 @@ jobs:
-p 6333:6333 \
qdrant/qdrant
- name: Setup MongoDB
if: matrix.vector-io-provider == 'remote::mongodb'
run: |
docker run --rm -d --pull always \
--name mongodb \
-p 27017:27017 \
-e MONGO_INITDB_ROOT_USERNAME=llamastack \
-e MONGO_INITDB_ROOT_PASSWORD=llamastack \
mongodb/mongodb-atlas-local:latest
- name: Wait for Qdrant to be ready
if: matrix.vector-io-provider == 'remote::qdrant'
run: |
@ -116,6 +126,21 @@ jobs:
docker logs qdrant
exit 1
- name: Wait for MongoDB to be ready
if: matrix.vector-io-provider == 'remote::mongodb'
run: |
echo "Waiting for MongoDB to be ready..."
for i in {1..30}; do
if docker exec mongodb mongosh --quiet --eval "db.adminCommand('ping').ok" > /dev/null 2>&1; then
echo "MongoDB is ready!"
exit 0
fi
sleep 2
done
echo "MongoDB failed to start"
docker logs mongodb
exit 1
- name: Wait for ChromaDB to be ready
if: matrix.vector-io-provider == 'remote::chromadb'
run: |
@ -170,6 +195,11 @@ jobs:
QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
ENABLE_MONGODB: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'true' || '' }}
MONGODB_HOST: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'localhost' || '' }}
MONGODB_PORT: ${{ matrix.vector-io-provider == 'remote::mongodb' && '27017' || '' }}
MONGODB_USERNAME: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'llamastack' || '' }}
MONGODB_PASSWORD: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'llamastack' || '' }}
run: |
uv run --no-sync \
pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
@ -196,6 +226,11 @@ jobs:
run: |
docker logs qdrant > qdrant.log
- name: Write MongoDB logs to file
if: ${{ always() && matrix.vector-io-provider == 'remote::mongodb' }}
run: |
docker logs mongodb > mongodb.log
- name: Upload all logs to artifacts
if: ${{ always() }}
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0

View file

@ -0,0 +1,276 @@
---
description: |
[MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It
uses MongoDB Atlas Vector Search to store and query vectors in the cloud.
That means you get enterprise-grade vector search with MongoDB's scalability and reliability.
## Features
- Cloud-native vector search with MongoDB Atlas
- Fully integrated with Llama Stack
- Enterprise-grade security and scalability
- Supports multiple search modes: vector, keyword, and hybrid search
- Built-in metadata filtering and text search capabilities
- Automatic index management
## Search Modes
MongoDB Atlas Vector Search supports three different search modes:
### Vector Search
Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search using embedding vectors.
```python
# Vector search example
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="What is machine learning?",
search_mode="vector",
max_num_results=5,
)
```
### Keyword Search
Keyword search uses MongoDB's text search capabilities with full-text indexes to find chunks containing specific terms.
```python
# Keyword search example
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="Python programming language",
search_mode="keyword",
max_num_results=5,
)
```
### Hybrid Search
Hybrid search combines both vector and keyword search methods using configurable reranking algorithms.
```python
# Hybrid search with RRF ranker (default)
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
)
# Hybrid search with weighted ranker
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
ranking_options={
"ranker": {
"type": "weighted",
"alpha": 0.7, # 70% vector search, 30% keyword search
}
},
)
```
## Usage
To use MongoDB Atlas in your Llama Stack project, follow these steps:
1. Create a MongoDB Atlas cluster with Vector Search enabled
2. Install the necessary dependencies
3. Configure your Llama Stack project to use MongoDB
4. Start storing and querying vectors
## Configuration
### Environment Variables
Set up the following environment variable for your MongoDB Atlas connection:
```bash
export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack"
```
### Configuration Example
```yaml
vector_io:
- provider_id: mongodb_atlas
provider_type: remote::mongodb
config:
connection_string: "${env.MONGODB_CONNECTION_STRING}"
database_name: "llama_stack"
index_name: "vector_index"
similarity_metric: "cosine"
```
## Installation
You can install the MongoDB Python driver using pip:
```bash
pip install pymongo
```
## Documentation
See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details about MongoDB Atlas Vector Search.
For general MongoDB documentation, visit [MongoDB Documentation](https://docs.mongodb.com/).
sidebar_label: Remote - Mongodb
title: remote::mongodb
---
# remote::mongodb
## Description
[MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It
uses MongoDB Atlas Vector Search to store and query vectors in the cloud.
That means you get enterprise-grade vector search with MongoDB's scalability and reliability.
## Features
- Cloud-native vector search with MongoDB Atlas
- Fully integrated with Llama Stack
- Enterprise-grade security and scalability
- Supports multiple search modes: vector, keyword, and hybrid search
- Built-in metadata filtering and text search capabilities
- Automatic index management
## Search Modes
MongoDB Atlas Vector Search supports three different search modes:
### Vector Search
Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search using embedding vectors.
```python
# Vector search example
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="What is machine learning?",
search_mode="vector",
max_num_results=5,
)
```
### Keyword Search
Keyword search uses MongoDB's text search capabilities with full-text indexes to find chunks containing specific terms.
```python
# Keyword search example
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="Python programming language",
search_mode="keyword",
max_num_results=5,
)
```
### Hybrid Search
Hybrid search combines both vector and keyword search methods using configurable reranking algorithms.
```python
# Hybrid search with RRF ranker (default)
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
)
# Hybrid search with weighted ranker
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
ranking_options={
"ranker": {
"type": "weighted",
"alpha": 0.7, # 70% vector search, 30% keyword search
}
},
)
```
## Usage
To use MongoDB Atlas in your Llama Stack project, follow these steps:
1. Create a MongoDB Atlas cluster with Vector Search enabled
2. Install the necessary dependencies
3. Configure your Llama Stack project to use MongoDB
4. Start storing and querying vectors
## Configuration
### Environment Variables
Set up the following environment variable for your MongoDB Atlas connection:
```bash
export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack"
```
### Configuration Example
```yaml
vector_io:
- provider_id: mongodb_atlas
provider_type: remote::mongodb
config:
connection_string: "${env.MONGODB_CONNECTION_STRING}"
database_name: "llama_stack"
index_name: "vector_index"
similarity_metric: "cosine"
```
## Installation
You can install the MongoDB Python driver using pip:
```bash
pip install pymongo
```
## Documentation
See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details about MongoDB Atlas Vector Search.
For general MongoDB documentation, visit [MongoDB Documentation](https://docs.mongodb.com/).
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `connection_string` | `str \| None` | No | | MongoDB connection string (e.g., mongodb://user:pass@localhost:27017/ or mongodb+srv://user:pass@cluster.mongodb.net/) |
| `host` | `str \| None` | No | | MongoDB host (used if connection_string is not provided) |
| `port` | `int \| None` | No | | MongoDB port (used if connection_string is not provided) |
| `username` | `str \| None` | No | | MongoDB username (used if connection_string is not provided) |
| `password` | `str \| None` | No | | MongoDB password (used if connection_string is not provided) |
| `database_name` | `<class 'str'>` | No | llama_stack | Database name to use for vector collections |
| `index_name` | `<class 'str'>` | No | vector_index | Name of the vector search index |
| `path_field` | `<class 'str'>` | No | embedding | Field name for storing embeddings |
| `similarity_metric` | `<class 'str'>` | No | cosine | Similarity metric: cosine, euclidean, or dotProduct |
| `max_pool_size` | `<class 'int'>` | No | 100 | Maximum connection pool size |
| `timeout_ms` | `<class 'int'>` | No | 30000 | Connection timeout in milliseconds |
| `persistence` | `llama_stack.core.storage.datatypes.KVStoreReference \| None` | No | | Config for KV store backend for metadata storage |
## Sample Configuration
```yaml
connection_string: ${env.MONGODB_CONNECTION_STRING:=}
host: ${env.MONGODB_HOST:=localhost}
port: ${env.MONGODB_PORT:=27017}
username: ${env.MONGODB_USERNAME:=}
password: ${env.MONGODB_PASSWORD:=}
database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack}
index_name: ${env.MONGODB_INDEX_NAME:=vector_index}
path_field: ${env.MONGODB_PATH_FIELD:=embedding}
similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine}
max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100}
timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000}
persistence:
namespace: vector_io::mongodb_atlas
backend: kv_default
```

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import MongoDBVectorIOConfig
async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .mongodb import MongoDBVectorIOAdapter
# Handle the deps resolution - if files API exists, pass it, otherwise None
files_api = deps.get(Api.files)
models_api = deps.get(Api.models)
impl = MongoDBVectorIOAdapter(config, deps[Api.inference], files_api, models_api)
await impl.initialize()
return impl

View file

@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class MongoDBVectorIOConfig(BaseModel):
"""Configuration for MongoDB Atlas Vector Search provider.
This provider connects to MongoDB Atlas and uses Vector Search for RAG operations.
"""
# MongoDB connection details - either connection_string or individual parameters
connection_string: str | None = Field(
default=None,
description="MongoDB connection string (e.g., mongodb://user:pass@localhost:27017/ or mongodb+srv://user:pass@cluster.mongodb.net/)",
)
host: str | None = Field(default=None, description="MongoDB host (used if connection_string is not provided)")
port: int | None = Field(default=None, description="MongoDB port (used if connection_string is not provided)")
username: str | None = Field(
default=None, description="MongoDB username (used if connection_string is not provided)"
)
password: str | None = Field(
default=None, description="MongoDB password (used if connection_string is not provided)"
)
database_name: str = Field(default="llama_stack", description="Database name to use for vector collections")
# Vector search configuration
index_name: str = Field(default="vector_index", description="Name of the vector search index")
path_field: str = Field(default="embedding", description="Field name for storing embeddings")
similarity_metric: str = Field(
default="cosine",
description="Similarity metric: cosine, euclidean, or dotProduct",
)
# Connection options
max_pool_size: int = Field(default=100, description="Maximum connection pool size")
timeout_ms: int = Field(default=30000, description="Connection timeout in milliseconds")
# KV store configuration
persistence: KVStoreReference | None = Field(
description="Config for KV store backend for metadata storage", default=None
)
def get_connection_string(self) -> str | None:
"""Build connection string from individual parameters if not provided directly.
If both connection_string and individual parameters (host/port) are provided,
individual parameters take precedence to allow test environment overrides.
"""
# Prioritize individual connection parameters over connection_string
# This allows test environments to override with MONGODB_HOST/PORT/etc
if self.host and self.port:
auth_part = ""
if self.username and self.password:
auth_part = f"{self.username}:{self.password}@"
return f"mongodb://{auth_part}{self.host}:{self.port}/"
# Fall back to connection_string if provided
if self.connection_string:
return self.connection_string
return None
@classmethod
def sample_run_config(
cls,
__distro_dir__: str,
connection_string: str = "${env.MONGODB_CONNECTION_STRING:=}",
host: str = "${env.MONGODB_HOST:=localhost}",
port: int = "${env.MONGODB_PORT:=27017}",
username: str = "${env.MONGODB_USERNAME:=}",
password: str = "${env.MONGODB_PASSWORD:=}",
database_name: str = "${env.MONGODB_DATABASE_NAME:=llama_stack}",
**kwargs: Any,
) -> dict[str, Any]:
return {
"connection_string": connection_string,
"host": host,
"port": port,
"username": username,
"password": password,
"database_name": database_name,
"index_name": "${env.MONGODB_INDEX_NAME:=vector_index}",
"path_field": "${env.MONGODB_PATH_FIELD:=embedding}",
"similarity_metric": "${env.MONGODB_SIMILARITY_METRIC:=cosine}",
"max_pool_size": "${env.MONGODB_MAX_POOL_SIZE:=100}",
"timeout_ms": "${env.MONGODB_TIMEOUT_MS:=30000}",
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::mongodb_atlas",
).model_dump(exclude_none=True),
}

View file

@ -0,0 +1,609 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import heapq
import time
from typing import Any
from numpy.typing import NDArray
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.operations import SearchIndexModel
from pymongo.server_api import ServerApi
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
VectorStoresProtocolPrivate,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorStoreWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import (
WeightedInMemoryAggregator,
sanitize_collection_name,
)
from .config import MongoDBVectorIOConfig
logger = get_logger(name=__name__, category="vector_io::mongodb")
VERSION = "v1"
VECTOR_DBS_PREFIX = f"vector_dbs:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:mongodb:{VERSION}::"
class MongoDBIndex(EmbeddingIndex):
"""MongoDB Atlas Vector Search index implementation optimized for RAG."""
def __init__(
self,
vector_store: VectorStore,
collection: Collection,
config: MongoDBVectorIOConfig,
):
self.vector_store = vector_store
self.collection = collection
self.config = config
self.dimension = vector_store.embedding_dimension
async def initialize(self) -> None:
"""Initialize the MongoDB collection and ensure vector search index exists."""
try:
# Create the collection if it doesn't exist
collection_names = self.collection.database.list_collection_names()
if self.collection.name not in collection_names:
logger.info(f"Creating collection '{self.collection.name}'")
# Create collection by inserting a dummy document
dummy_doc = {"_id": "__dummy__", "dummy": True}
self.collection.insert_one(dummy_doc)
# Remove the dummy document
self.collection.delete_one({"_id": "__dummy__"})
logger.info(f"Collection '{self.collection.name}' created successfully")
# Create optimized vector search index for RAG
await self._create_vector_search_index()
# Create text index for hybrid search
await self._ensure_text_index()
except Exception as e:
logger.exception(
f"Failed to initialize MongoDB index for vector_store: {self.vector_store.identifier}. "
f"Collection name: {self.collection.name}. Error: {str(e)}"
)
# Don't fail completely - just log the error and continue
logger.warning(
"Continuing without complete index initialization. "
"You may need to create indexes manually in MongoDB Atlas dashboard."
)
async def _create_vector_search_index(self) -> None:
"""Create optimized vector search index based on MongoDB RAG best practices."""
try:
# Check if vector search index exists
indexes = list(self.collection.list_search_indexes())
index_exists = any(idx.get("name") == self.config.index_name for idx in indexes)
if not index_exists:
# Create vector search index optimized for RAG
# Based on MongoDB's RAG example using new vectorSearch format
search_index_model = SearchIndexModel(
definition={
"fields": [
{
"type": "vector",
"numDimensions": self.dimension,
"path": self.config.path_field,
"similarity": self._convert_similarity_metric(self.config.similarity_metric),
}
]
},
name=self.config.index_name,
type="vectorSearch",
)
logger.info(
f"Creating vector search index '{self.config.index_name}' for RAG on collection '{self.collection.name}'"
)
self.collection.create_search_index(model=search_index_model)
# Wait for index to be ready (like in MongoDB RAG example)
await self._wait_for_index_ready()
logger.info("Vector search index created and ready for RAG queries")
except Exception as e:
logger.warning(f"Failed to create vector search index: {e}")
def _convert_similarity_metric(self, metric: str) -> str:
"""Convert internal similarity metric to MongoDB Atlas format."""
metric_map = {
"cosine": "cosine",
"euclidean": "euclidean",
"dotProduct": "dotProduct",
"dot_product": "dotProduct",
}
return metric_map.get(metric, "cosine")
async def _wait_for_index_ready(self) -> None:
"""Wait for the vector search index to be ready, based on MongoDB RAG example."""
logger.info("Waiting for vector search index to be ready...")
max_wait_time = 300 # 5 minutes max wait
wait_interval = 5
elapsed_time = 0
while elapsed_time < max_wait_time:
try:
indices = list(self.collection.list_search_indexes(self.config.index_name))
if len(indices) and indices[0].get("queryable") is True:
logger.info(f"Vector search index '{self.config.index_name}' is ready for querying")
return
except Exception:
pass
time.sleep(wait_interval)
elapsed_time += wait_interval
logger.warning(f"Vector search index may not be fully ready after {max_wait_time}s")
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray) -> None:
"""Add chunks with embeddings to MongoDB collection optimized for RAG."""
if len(chunks) != len(embeddings):
raise ValueError(f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}")
documents = []
for i, chunk in enumerate(chunks):
# Structure document for optimal RAG retrieval
doc = {
"_id": chunk.chunk_id,
"chunk_id": chunk.chunk_id,
"text": interleaved_content_as_str(chunk.content), # Key field for RAG context
"content": interleaved_content_as_str(chunk.content), # Backward compatibility
"metadata": chunk.metadata or {},
"chunk_metadata": (chunk.chunk_metadata.model_dump() if chunk.chunk_metadata else {}),
self.config.path_field: embeddings[i].tolist(), # Vector embeddings
"document": chunk.model_dump(), # Full chunk data
}
documents.append(doc)
try:
# Use upsert behavior for chunks
for doc in documents:
self.collection.replace_one({"_id": doc["_id"]}, doc, upsert=True)
logger.debug(f"Successfully added {len(chunks)} chunks optimized for RAG to MongoDB collection")
except Exception as e:
logger.exception(f"Failed to add chunks to MongoDB collection: {e}")
raise
async def query_vector(
self,
embedding: NDArray,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""Perform vector similarity search optimized for RAG based on MongoDB example."""
try:
# Use MongoDB's vector search aggregation pipeline optimized for RAG
pipeline = [
{
"$vectorSearch": {
"index": self.config.index_name,
"queryVector": embedding.tolist(),
"path": self.config.path_field,
"numCandidates": min(k * 10, 1000), # Cap at 1000 to prevent excessive candidates
"limit": k,
}
},
{
"$project": {
"_id": 0,
"text": 1, # Primary field for RAG context
"content": 1, # Backward compatibility
"metadata": 1,
"chunk_metadata": 1,
"document": 1,
"score": {"$meta": "vectorSearchScore"},
}
},
{"$match": {"score": {"$gte": score_threshold}}},
]
results = list(self.collection.aggregate(pipeline))
chunks = []
scores = []
for result in results:
score = result.get("score", 0.0)
if score >= score_threshold:
chunk_data = result.get("document", {})
if chunk_data:
chunks.append(Chunk(**chunk_data))
scores.append(float(score))
logger.debug(f"Vector search for RAG returned {len(chunks)} results")
return QueryChunksResponse(chunks=chunks, scores=scores)
except Exception as e:
logger.exception(f"Vector search for RAG failed: {e}")
raise RuntimeError(f"Vector search for RAG failed: {e}") from e
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""Perform text search using MongoDB's text search for RAG context retrieval."""
try:
# Ensure text index exists
await self._ensure_text_index()
pipeline: list[dict[str, Any]] = [
{"$match": {"$text": {"$search": query_string}}},
{
"$project": {
"_id": 0,
"text": 1, # Primary field for RAG context
"content": 1, # Backward compatibility
"metadata": 1,
"chunk_metadata": 1,
"document": 1,
"score": {"$meta": "textScore"},
}
},
{"$match": {"score": {"$gte": score_threshold}}},
{"$sort": {"score": {"$meta": "textScore"}}},
{"$limit": k},
]
results = list(self.collection.aggregate(pipeline))
chunks = []
scores = []
for result in results:
score = result.get("score", 0.0)
if score >= score_threshold:
chunk_data = result.get("document", {})
if chunk_data:
chunks.append(Chunk(**chunk_data))
scores.append(float(score))
logger.debug(f"Keyword search for RAG returned {len(chunks)} results")
return QueryChunksResponse(chunks=chunks, scores=scores)
except Exception as e:
logger.exception(f"Keyword search for RAG failed: {e}")
raise RuntimeError(f"Keyword search for RAG failed: {e}") from e
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
"""Perform hybrid search for enhanced RAG context retrieval."""
if reranker_params is None:
reranker_params = {}
# Get results from both search methods
vector_response = await self.query_vector(embedding, k, 0.0)
keyword_response = await self.query_keyword(query_string, k, 0.0)
# Convert responses to score dictionaries
vector_scores = {
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
chunk.chunk_id: score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
# Combine scores using the reranking utility
combined_scores = WeightedInMemoryAggregator.combine_search_results(
vector_scores, keyword_scores, reranker_type, reranker_params
)
# Get top-k results
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
# Filter by score threshold
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create chunk map
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
# Build final results
chunks = []
scores = []
for doc_id, score in filtered_items:
if doc_id in chunk_map:
chunks.append(chunk_map[doc_id])
scores.append(score)
logger.debug(f"Hybrid search for RAG returned {len(chunks)} results")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from MongoDB collection."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
try:
result = self.collection.delete_many({"_id": {"$in": chunk_ids}})
logger.debug(f"Deleted {result.deleted_count} chunks from MongoDB collection")
except Exception as e:
logger.exception(f"Failed to delete chunks: {e}")
raise
async def delete(self) -> None:
"""Delete the entire collection."""
try:
self.collection.drop()
logger.debug(f"Dropped MongoDB collection: {self.collection.name}")
except Exception as e:
logger.exception(f"Failed to drop collection: {e}")
raise
async def _ensure_text_index(self) -> None:
"""Ensure text search index exists on content fields for RAG."""
try:
indexes = list(self.collection.list_indexes())
text_index_exists = any(
any(key.startswith(("content", "text")) for key in idx.get("key", {}).keys())
and idx.get("textIndexVersion") is not None
for idx in indexes
)
if not text_index_exists:
logger.info("Creating text search index on content fields for RAG")
# Index both 'text' and 'content' fields for comprehensive text search
self.collection.create_index([("text", "text"), ("content", "text")])
logger.info("Text search index created successfully for RAG")
except Exception as e:
logger.warning(f"Failed to create text index for RAG: {e}")
class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
"""MongoDB Atlas Vector Search adapter for Llama Stack optimized for RAG workflows."""
def __init__(
self,
config: MongoDBVectorIOConfig,
inference_api,
files_api=None,
models_api=None,
) -> None:
# Handle the case where files_api might be a ProviderSpec that needs resolution
resolved_files_api = files_api
super().__init__(files_api=resolved_files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.client: MongoClient | None = None
self.database: Database | None = None
self.cache: dict[str, VectorStoreWithIndex] = {}
self.kvstore: KVStore | None = None
async def initialize(self) -> None:
"""Initialize MongoDB connection optimized for RAG workflows."""
logger.info("Initializing MongoDB Atlas Vector IO adapter for RAG")
try:
# Initialize KV store for metadata
if self.config.persistence:
self.kvstore = await kvstore_impl(self.config.persistence)
# Skip MongoDB connection if no connection string provided
# This allows other providers to work without MongoDB credentials
if not self.config.connection_string:
logger.warning(
"MongoDB connection_string not provided. "
"MongoDB vector store will not be available until credentials are configured."
)
return
# Connect to MongoDB with optimized settings for RAG
self.client = MongoClient(
self.config.connection_string,
server_api=ServerApi("1"),
maxPoolSize=self.config.max_pool_size,
serverSelectionTimeoutMS=self.config.timeout_ms,
# Additional settings for RAG performance
retryWrites=True,
readPreference="primaryPreferred",
)
# Test connection
self.client.admin.command("ping")
logger.info("Successfully connected to MongoDB Atlas for RAG")
# Get database
self.database = self.client[self.config.database_name]
# Initialize OpenAI vector stores
await self.initialize_openai_vector_stores()
# Load existing vector databases
await self._load_existing_vector_dbs()
logger.info("MongoDB Atlas Vector IO adapter for RAG initialized successfully")
except Exception as e:
logger.exception("Failed to initialize MongoDB Atlas Vector IO adapter for RAG")
raise RuntimeError("Failed to initialize MongoDB Atlas Vector IO adapter for RAG") from e
async def shutdown(self) -> None:
"""Shutdown MongoDB connection."""
if self.client:
self.client.close()
logger.info("MongoDB Atlas RAG connection closed")
async def health(self) -> HealthResponse:
"""Perform health check on MongoDB connection."""
try:
if self.client:
self.client.admin.command("ping")
return HealthResponse(status=HealthStatus.OK)
else:
return HealthResponse(status=HealthStatus.ERROR, message="MongoDB client not initialized")
except Exception as e:
return HealthResponse(
status=HealthStatus.ERROR,
message=f"MongoDB RAG health check failed: {str(e)}",
)
async def register_vector_store(self, vector_store: VectorStore) -> None:
"""Register a new vector store optimized for RAG."""
if self.database is None:
raise RuntimeError("MongoDB database not initialized")
# Create collection name from vector store identifier
collection_name = sanitize_collection_name(vector_store.identifier)
collection = self.database[collection_name]
# Create and initialize MongoDB index optimized for RAG
mongodb_index = MongoDBIndex(vector_store, collection, self.config)
await mongodb_index.initialize()
# Create vector store with index wrapper
vector_store_with_index = VectorStoreWithIndex(
vector_store=vector_store,
index=mongodb_index,
inference_api=self.inference_api,
)
# Cache the vector store
self.cache[vector_store.identifier] = vector_store_with_index
# Save vector store info to KVStore for persistence
if self.kvstore:
await self.kvstore.set(
f"{VECTOR_DBS_PREFIX}{vector_store.identifier}",
vector_store.model_dump_json(),
)
logger.info(f"Registered vector store for RAG: {vector_store.identifier}")
async def unregister_vector_store(self, vector_store_id: str) -> None:
"""Unregister a vector store."""
if vector_store_id in self.cache:
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
# Clean up from KV store
if self.kvstore:
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
logger.info(f"Unregistered vector store: {vector_store_id}")
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
"""Insert chunks into the vector database optimized for RAG."""
vector_db_with_index = await self._get_vector_db_index(vector_db_id)
await vector_db_with_index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
"""Query chunks from the vector database optimized for RAG context retrieval."""
vector_db_with_index = await self._get_vector_db_index(vector_db_id)
return await vector_db_with_index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from the vector database."""
vector_db_with_index = await self._get_vector_db_index(store_id)
await vector_db_with_index.index.delete_chunks(chunks_for_deletion)
async def _get_vector_db_index(self, vector_db_id: str) -> VectorStoreWithIndex:
"""Get vector store index from cache."""
if vector_db_id in self.cache:
return self.cache[vector_db_id]
raise VectorStoreNotFoundError(vector_db_id)
async def _load_existing_vector_dbs(self) -> None:
"""Load existing vector databases from KVStore."""
if not self.kvstore:
return
try:
# Use keys_in_range to get all vector database keys from KVStore
# This searches for keys with the prefix by using range scan
start_key = VECTOR_DBS_PREFIX
# Create an end key by incrementing the last character
end_key = VECTOR_DBS_PREFIX[:-1] + chr(ord(VECTOR_DBS_PREFIX[-1]) + 1)
vector_db_keys = await self.kvstore.keys_in_range(start_key, end_key)
for key in vector_db_keys:
try:
vector_store_data = await self.kvstore.get(key)
if vector_store_data:
import json
vector_store = VectorStore(**json.loads(vector_store_data))
# Register the vector store without re-initializing
await self._register_existing_vector_store(vector_store)
logger.info(f"Loaded existing RAG-optimized vector store: {vector_store.identifier}")
except Exception as e:
logger.warning(f"Failed to load vector store from key {key}: {e}")
continue
except Exception as e:
logger.warning(f"Failed to load existing vector stores: {e}")
async def _register_existing_vector_store(self, vector_store: VectorStore) -> None:
"""Register an existing vector store without re-initialization."""
if self.database is None:
raise RuntimeError("MongoDB database not initialized")
# Create collection name from vector store identifier
collection_name = sanitize_collection_name(vector_store.identifier)
collection = self.database[collection_name]
# Create MongoDB index without initialization (collection already exists)
mongodb_index = MongoDBIndex(vector_store, collection, self.config)
# Create vector store with index wrapper
vector_store_with_index = VectorStoreWithIndex(
vector_store=vector_store,
index=mongodb_index,
inference_api=self.inference_api,
)
# Cache the vector store
self.cache[vector_store.identifier] = vector_store_with_index

View file

@ -25,6 +25,7 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::mongodb
- provider_type: remote::qdrant
- provider_type: remote::weaviate
files:

View file

@ -131,6 +131,23 @@ providers:
persistence:
namespace: vector_io::pgvector
backend: kv_default
- provider_id: ${env.MONGODB_CONNECTION_STRING:+mongodb_atlas}
provider_type: remote::mongodb
config:
connection_string: ${env.MONGODB_CONNECTION_STRING:=}
host: ${env.MONGODB_HOST:=localhost}
port: ${env.MONGODB_PORT:=27017}
username: ${env.MONGODB_USERNAME:=}
password: ${env.MONGODB_PASSWORD:=}
database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack}
index_name: ${env.MONGODB_INDEX_NAME:=vector_index}
path_field: ${env.MONGODB_PATH_FIELD:=embedding}
similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine}
max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100}
timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000}
persistence:
namespace: vector_io::mongodb_atlas
backend: kv_default
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:

View file

@ -26,6 +26,7 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::mongodb
- provider_type: remote::qdrant
- provider_type: remote::weaviate
files:

View file

@ -131,6 +131,23 @@ providers:
persistence:
namespace: vector_io::pgvector
backend: kv_default
- provider_id: ${env.MONGODB_CONNECTION_STRING:+mongodb_atlas}
provider_type: remote::mongodb
config:
connection_string: ${env.MONGODB_CONNECTION_STRING:=}
host: ${env.MONGODB_HOST:=localhost}
port: ${env.MONGODB_PORT:=27017}
username: ${env.MONGODB_USERNAME:=}
password: ${env.MONGODB_PASSWORD:=}
database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack}
index_name: ${env.MONGODB_INDEX_NAME:=vector_index}
path_field: ${env.MONGODB_PATH_FIELD:=embedding}
similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine}
max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100}
timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000}
persistence:
namespace: vector_io::mongodb_atlas
backend: kv_default
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:

View file

@ -26,6 +26,7 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::mongodb
- provider_type: remote::qdrant
- provider_type: remote::weaviate
files:

View file

@ -131,6 +131,23 @@ providers:
persistence:
namespace: vector_io::pgvector
backend: kv_default
- provider_id: ${env.MONGODB_CONNECTION_STRING:+mongodb_atlas}
provider_type: remote::mongodb
config:
connection_string: ${env.MONGODB_CONNECTION_STRING:=}
host: ${env.MONGODB_HOST:=localhost}
port: ${env.MONGODB_PORT:=27017}
username: ${env.MONGODB_USERNAME:=}
password: ${env.MONGODB_PASSWORD:=}
database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack}
index_name: ${env.MONGODB_INDEX_NAME:=vector_index}
path_field: ${env.MONGODB_PATH_FIELD:=embedding}
similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine}
max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100}
timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000}
persistence:
namespace: vector_io::mongodb_atlas
backend: kv_default
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:

View file

@ -36,11 +36,14 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
)
from llama_stack.providers.registry.inference import available_providers
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.mongodb.config import MongoDBVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.config import (
WeaviateVectorIOConfig,
)
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
@ -124,6 +127,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
BuildProvider(provider_type="inline::milvus"),
BuildProvider(provider_type="remote::chromadb"),
BuildProvider(provider_type="remote::pgvector"),
BuildProvider(provider_type="remote::mongodb"),
BuildProvider(provider_type="remote::qdrant"),
BuildProvider(provider_type="remote::weaviate"),
],
@ -254,7 +258,70 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
additional_pip_packages=list(set(PostgresSqlStoreConfig.pip_packages() + PostgresKVStoreConfig.pip_packages())),
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides=default_overrides,
provider_overrides={
"inference": remote_inference_providers + [embedding_provider],
"vector_io": [
Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.MILVUS_URL:+milvus}",
provider_type="inline::milvus",
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.CHROMADB_URL:+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}/",
url="${env.CHROMADB_URL:=}",
),
),
Provider(
provider_id="${env.PGVECTOR_DB:+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
db="${env.PGVECTOR_DB:=}",
user="${env.PGVECTOR_USER:=}",
password="${env.PGVECTOR_PASSWORD:=}",
),
),
Provider(
provider_id="${env.MONGODB_CONNECTION_STRING:+mongodb_atlas}",
provider_type="remote::mongodb",
config=MongoDBVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
connection_string="${env.MONGODB_CONNECTION_STRING:=}",
database_name="${env.MONGODB_DATABASE_NAME:=llama_stack}",
),
),
Provider(
provider_id="${env.QDRANT_URL:+qdrant}",
provider_type="remote::qdrant",
config=QdrantVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
url="${env.QDRANT_URL:=}",
),
),
Provider(
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
provider_type="remote::weaviate",
config=WeaviateVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
cluster_url="${env.WEAVIATE_CLUSTER_URL:=}",
),
),
],
"files": [files_provider],
},
default_models=[],
default_tool_groups=default_tool_groups,
default_shields=default_shields,
@ -384,5 +451,13 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
"azure",
"Azure API Type",
),
"MONGODB_CONNECTION_STRING": (
"",
"MongoDB Atlas connection string (e.g., mongodb+srv://user:pass@cluster.mongodb.net/)",
),
"MONGODB_DATABASE_NAME": (
"llama_stack",
"MongoDB database name",
),
},
)

View file

@ -823,6 +823,132 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
optional_api_dependencies=[Api.files, Api.models],
description="""
Please refer to the remote provider documentation.
""",
),
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="mongodb",
provider_type="remote::mongodb",
pip_packages=["pymongo>=4.0.0"],
module="llama_stack.providers.remote.vector_io.mongodb",
config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It
uses MongoDB Atlas Vector Search to store and query vectors in the cloud.
That means you get enterprise-grade vector search with MongoDB's scalability and reliability.
## Features
- Cloud-native vector search with MongoDB Atlas
- Fully integrated with Llama Stack
- Enterprise-grade security and scalability
- Supports multiple search modes: vector, keyword, and hybrid search
- Built-in metadata filtering and text search capabilities
- Automatic index management
## Search Modes
MongoDB Atlas Vector Search supports three different search modes:
### Vector Search
Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search using embedding vectors.
```python
# Vector search example
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="What is machine learning?",
search_mode="vector",
max_num_results=5,
)
```
### Keyword Search
Keyword search uses MongoDB's text search capabilities with full-text indexes to find chunks containing specific terms.
```python
# Keyword search example
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="Python programming language",
search_mode="keyword",
max_num_results=5,
)
```
### Hybrid Search
Hybrid search combines both vector and keyword search methods using configurable reranking algorithms.
```python
# Hybrid search with RRF ranker (default)
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
)
# Hybrid search with weighted ranker
search_response = client.vector_stores.search(
vector_store_id=vector_store.id,
query="neural networks in Python",
search_mode="hybrid",
max_num_results=5,
ranking_options={
"ranker": {
"type": "weighted",
"alpha": 0.7, # 70% vector search, 30% keyword search
}
},
)
```
## Usage
To use MongoDB Atlas in your Llama Stack project, follow these steps:
1. Create a MongoDB Atlas cluster with Vector Search enabled
2. Install the necessary dependencies
3. Configure your Llama Stack project to use MongoDB
4. Start storing and querying vectors
## Configuration
### Environment Variables
Set up the following environment variable for your MongoDB Atlas connection:
```bash
export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack"
```
### Configuration Example
```yaml
vector_io:
- provider_id: mongodb_atlas
provider_type: remote::mongodb
config:
connection_string: "${env.MONGODB_CONNECTION_STRING}"
database_name: "llama_stack"
index_name: "vector_index"
similarity_metric: "cosine"
```
## Installation
You can install the MongoDB Python driver using pip:
```bash
pip install pymongo
```
## Documentation
See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details about MongoDB Atlas Vector Search.
For general MongoDB documentation, visit [MongoDB Documentation](https://docs.mongodb.com/).
""",
),
]

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import MongoDBVectorIOConfig
async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .mongodb import MongoDBVectorIOAdapter
# Handle the deps resolution - if files API exists, pass it, otherwise None
files_api = deps.get(Api.files)
models_api = deps.get(Api.models)
impl = MongoDBVectorIOAdapter(config, deps[Api.inference], files_api, models_api)
await impl.initialize()
return impl

View file

@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class MongoDBVectorIOConfig(BaseModel):
"""Configuration for MongoDB Atlas Vector Search provider.
This provider connects to MongoDB Atlas and uses Vector Search for RAG operations.
"""
# MongoDB connection details - either connection_string or individual parameters
connection_string: str | None = Field(
default=None,
description="MongoDB connection string (e.g., mongodb://user:pass@localhost:27017/ or mongodb+srv://user:pass@cluster.mongodb.net/)",
)
host: str | None = Field(
default=None,
description="MongoDB host (used if connection_string is not provided)",
)
port: int | None = Field(
default=None,
description="MongoDB port (used if connection_string is not provided)",
)
username: str | None = Field(
default=None,
description="MongoDB username (used if connection_string is not provided)",
)
password: str | None = Field(
default=None,
description="MongoDB password (used if connection_string is not provided)",
)
database_name: str = Field(default="llama_stack", description="Database name to use for vector collections")
# Vector search configuration
index_name: str = Field(default="vector_index", description="Name of the vector search index")
path_field: str = Field(default="embedding", description="Field name for storing embeddings")
similarity_metric: str = Field(
default="cosine",
description="Similarity metric: cosine, euclidean, or dotProduct",
)
# Connection options
max_pool_size: int = Field(default=100, description="Maximum connection pool size")
timeout_ms: int = Field(default=30000, description="Connection timeout in milliseconds")
# KV store configuration
persistence: KVStoreReference | None = Field(
description="Config for KV store backend for metadata storage", default=None
)
def get_connection_string(self) -> str | None:
"""Build connection string from individual parameters if not provided directly.
If both connection_string and individual parameters (host/port) are provided,
individual parameters take precedence to allow test environment overrides.
"""
# Prioritize individual connection parameters over connection_string
# This allows test environments to override with MONGODB_HOST/PORT/etc
if self.host and self.port:
auth_part = ""
if self.username and self.password:
auth_part = f"{self.username}:{self.password}@"
return f"mongodb://{auth_part}{self.host}:{self.port}/"
# Fall back to connection_string if provided
if self.connection_string:
return self.connection_string
return None
@classmethod
def sample_run_config(
cls,
__distro_dir__: str,
connection_string: str = "${env.MONGODB_CONNECTION_STRING:=}",
host: str = "${env.MONGODB_HOST:=localhost}",
port: str = "${env.MONGODB_PORT:=27017}",
username: str = "${env.MONGODB_USERNAME:=}",
password: str = "${env.MONGODB_PASSWORD:=}",
database_name: str = "${env.MONGODB_DATABASE_NAME:=llama_stack}",
**kwargs: Any,
) -> dict[str, Any]:
return {
"connection_string": connection_string,
"host": host,
"port": port,
"username": username,
"password": password,
"database_name": database_name,
"index_name": "${env.MONGODB_INDEX_NAME:=vector_index}",
"path_field": "${env.MONGODB_PATH_FIELD:=embedding}",
"similarity_metric": "${env.MONGODB_SIMILARITY_METRIC:=cosine}",
"max_pool_size": "${env.MONGODB_MAX_POOL_SIZE:=100}",
"timeout_ms": "${env.MONGODB_TIMEOUT_MS:=30000}",
"persistence": KVStoreReference(
backend="kv_default",
namespace="vector_io::mongodb_atlas",
).model_dump(exclude_none=True),
}

View file

@ -0,0 +1,631 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import heapq
import time
from typing import Any
from numpy.typing import NDArray
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.operations import SearchIndexModel
from pymongo.server_api import ServerApi
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
VectorStoresProtocolPrivate,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.vector_store import (
ChunkForDeletion,
EmbeddingIndex,
VectorStoreWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import (
WeightedInMemoryAggregator,
sanitize_collection_name,
)
from .config import MongoDBVectorIOConfig
logger = get_logger(name=__name__, category="vector_io::mongodb")
VERSION = "v1"
VECTOR_DBS_PREFIX = f"vector_dbs:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:mongodb:{VERSION}::"
class MongoDBIndex(EmbeddingIndex):
"""MongoDB Atlas Vector Search index implementation optimized for RAG."""
def __init__(
self,
vector_store: VectorStore,
collection: Collection,
config: MongoDBVectorIOConfig,
):
self.vector_store = vector_store
self.collection = collection
self.config = config
self.dimension = vector_store.embedding_dimension
async def initialize(self) -> None:
"""Initialize the MongoDB collection and ensure vector search index exists."""
try:
# Create the collection if it doesn't exist
collection_names = self.collection.database.list_collection_names()
if self.collection.name not in collection_names:
logger.info(f"Creating collection '{self.collection.name}'")
# Create collection by inserting a dummy document
dummy_doc = {"_id": "__dummy__", "dummy": True}
self.collection.insert_one(dummy_doc)
# Remove the dummy document
self.collection.delete_one({"_id": "__dummy__"})
logger.info(f"Collection '{self.collection.name}' created successfully")
# Create optimized vector search index for RAG
await self._create_vector_search_index()
# Create text index for hybrid search
await self._ensure_text_index()
except Exception as e:
logger.exception(
f"Failed to initialize MongoDB index for vector_store: {self.vector_store.identifier}. "
f"Collection name: {self.collection.name}. Error: {str(e)}"
)
# Don't fail completely - just log the error and continue
logger.warning(
"Continuing without complete index initialization. "
"You may need to create indexes manually in MongoDB Atlas dashboard."
)
async def _create_vector_search_index(self) -> None:
"""Create optimized vector search index based on MongoDB RAG best practices."""
try:
# Check if vector search index exists
indexes = list(self.collection.list_search_indexes())
index_exists = any(idx.get("name") == self.config.index_name for idx in indexes)
if not index_exists:
# Create vector search index optimized for RAG
# Based on MongoDB's RAG example using new vectorSearch format
search_index_model = SearchIndexModel(
definition={
"fields": [
{
"type": "vector",
"numDimensions": self.dimension,
"path": self.config.path_field,
"similarity": self._convert_similarity_metric(self.config.similarity_metric),
}
]
},
name=self.config.index_name,
type="vectorSearch",
)
logger.info(
f"Creating vector search index '{self.config.index_name}' for RAG on collection '{self.collection.name}'"
)
self.collection.create_search_index(model=search_index_model)
# Wait for index to be ready (like in MongoDB RAG example)
await self._wait_for_index_ready()
logger.info("Vector search index created and ready for RAG queries")
except Exception as e:
logger.warning(f"Failed to create vector search index: {e}")
def _convert_similarity_metric(self, metric: str) -> str:
"""Convert internal similarity metric to MongoDB Atlas format."""
metric_map = {
"cosine": "cosine",
"euclidean": "euclidean",
"dotProduct": "dotProduct",
"dot_product": "dotProduct",
}
return metric_map.get(metric, "cosine")
async def _wait_for_index_ready(self) -> None:
"""Wait for the vector search index to be ready, based on MongoDB RAG example."""
logger.info("Waiting for vector search index to be ready...")
max_wait_time = 300 # 5 minutes max wait
wait_interval = 5
elapsed_time = 0
while elapsed_time < max_wait_time:
try:
indices = list(self.collection.list_search_indexes(self.config.index_name))
if len(indices) and indices[0].get("queryable") is True:
logger.info(f"Vector search index '{self.config.index_name}' is ready for querying")
return
except Exception:
pass
time.sleep(wait_interval)
elapsed_time += wait_interval
logger.warning(f"Vector search index may not be fully ready after {max_wait_time}s")
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray) -> None:
"""Add chunks with embeddings to MongoDB collection optimized for RAG."""
if len(chunks) != len(embeddings):
raise ValueError(f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}")
documents = []
for i, chunk in enumerate(chunks):
# Structure document for optimal RAG retrieval
doc = {
"_id": chunk.chunk_id,
"chunk_id": chunk.chunk_id,
"text": interleaved_content_as_str(chunk.content), # Key field for RAG context
"content": interleaved_content_as_str(chunk.content), # Backward compatibility
"metadata": chunk.metadata or {},
"chunk_metadata": (chunk.chunk_metadata.model_dump() if chunk.chunk_metadata else {}),
self.config.path_field: embeddings[i].tolist(), # Vector embeddings
"document": chunk.model_dump(), # Full chunk data
}
documents.append(doc)
try:
# Use upsert behavior for chunks
for doc in documents:
self.collection.replace_one({"_id": doc["_id"]}, doc, upsert=True)
logger.debug(f"Successfully added {len(chunks)} chunks optimized for RAG to MongoDB collection")
except Exception as e:
logger.exception(f"Failed to add chunks to MongoDB collection: {e}")
raise
async def query_vector(
self,
embedding: NDArray,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""Perform vector similarity search optimized for RAG based on MongoDB example."""
try:
# Use MongoDB's vector search aggregation pipeline optimized for RAG
pipeline = [
{
"$vectorSearch": {
"index": self.config.index_name,
"queryVector": embedding.tolist(),
"path": self.config.path_field,
"numCandidates": min(k * 10, 1000), # Cap at 1000 to prevent excessive candidates
"limit": k,
}
},
{
"$project": {
"_id": 0,
"text": 1, # Primary field for RAG context
"content": 1, # Backward compatibility
"metadata": 1,
"chunk_metadata": 1,
"document": 1,
"score": {"$meta": "vectorSearchScore"},
}
},
{"$match": {"score": {"$gte": score_threshold}}},
]
results = list(self.collection.aggregate(pipeline))
chunks = []
scores = []
for result in results:
score = result.get("score", 0.0)
if score >= score_threshold:
chunk_data = result.get("document", {})
if chunk_data:
chunks.append(Chunk(**chunk_data))
scores.append(float(score))
logger.debug(f"Vector search for RAG returned {len(chunks)} results")
return QueryChunksResponse(chunks=chunks, scores=scores)
except Exception as e:
logger.exception(f"Vector search for RAG failed: {e}")
raise RuntimeError(f"Vector search for RAG failed: {e}") from e
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""Perform text search using MongoDB's text search for RAG context retrieval."""
try:
# Ensure text index exists
await self._ensure_text_index()
pipeline: list[dict[str, Any]] = [
{"$match": {"$text": {"$search": query_string}}},
{
"$project": {
"_id": 0,
"text": 1, # Primary field for RAG context
"content": 1, # Backward compatibility
"metadata": 1,
"chunk_metadata": 1,
"document": 1,
"score": {"$meta": "textScore"},
}
},
{"$match": {"score": {"$gte": score_threshold}}},
{"$sort": {"score": {"$meta": "textScore"}}},
{"$limit": k},
]
results = list(self.collection.aggregate(pipeline))
chunks = []
scores = []
for result in results:
score = result.get("score", 0.0)
if score >= score_threshold:
chunk_data = result.get("document", {})
if chunk_data:
chunks.append(Chunk(**chunk_data))
scores.append(float(score))
logger.debug(f"Keyword search for RAG returned {len(chunks)} results")
return QueryChunksResponse(chunks=chunks, scores=scores)
except Exception as e:
logger.exception(f"Keyword search for RAG failed: {e}")
raise RuntimeError(f"Keyword search for RAG failed: {e}") from e
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
"""Perform hybrid search for enhanced RAG context retrieval."""
if reranker_params is None:
reranker_params = {}
# Get results from both search methods
vector_response = await self.query_vector(embedding, k, 0.0)
keyword_response = await self.query_keyword(query_string, k, 0.0)
# Convert responses to score dictionaries
vector_scores = {
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
chunk.chunk_id: score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
# Combine scores using the reranking utility
combined_scores = WeightedInMemoryAggregator.combine_search_results(
vector_scores, keyword_scores, reranker_type, reranker_params
)
# Get top-k results
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
# Filter by score threshold
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create chunk map
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
# Build final results
chunks = []
scores = []
for doc_id, score in filtered_items:
if doc_id in chunk_map:
chunks.append(chunk_map[doc_id])
scores.append(score)
logger.debug(f"Hybrid search for RAG returned {len(chunks)} results")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from MongoDB collection."""
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
try:
result = self.collection.delete_many({"_id": {"$in": chunk_ids}})
logger.debug(f"Deleted {result.deleted_count} chunks from MongoDB collection")
except Exception as e:
logger.exception(f"Failed to delete chunks: {e}")
raise
async def delete(self) -> None:
"""Delete the entire collection."""
try:
self.collection.drop()
logger.debug(f"Dropped MongoDB collection: {self.collection.name}")
except Exception as e:
logger.exception(f"Failed to drop collection: {e}")
raise
async def _ensure_text_index(self) -> None:
"""Ensure text search index exists on content fields for RAG."""
try:
indexes = list(self.collection.list_indexes())
text_index_exists = any(
any(key.startswith(("content", "text")) for key in idx.get("key", {}).keys())
and idx.get("textIndexVersion") is not None
for idx in indexes
)
if not text_index_exists:
logger.info("Creating text search index on content fields for RAG")
# Index both 'text' and 'content' fields for comprehensive text search
self.collection.create_index([("text", "text"), ("content", "text")])
logger.info("Text search index created successfully for RAG")
except Exception as e:
logger.warning(f"Failed to create text index for RAG: {e}")
class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
"""MongoDB Atlas Vector Search adapter for Llama Stack optimized for RAG workflows."""
def __init__(
self,
config: MongoDBVectorIOConfig,
inference_api,
files_api=None,
models_api=None,
) -> None:
# Handle the case where files_api might be a ProviderSpec that needs resolution
resolved_files_api = files_api
super().__init__(files_api=resolved_files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.models_api = models_api
self.client: MongoClient | None = None
self.database: Database | None = None
self.cache: dict[str, VectorStoreWithIndex] = {}
self.kvstore: KVStore | None = None
async def initialize(self) -> None:
"""Initialize MongoDB connection optimized for RAG workflows."""
logger.info("Initializing MongoDB Atlas Vector IO adapter for RAG")
try:
# Initialize KV store for metadata
if self.config.persistence:
self.kvstore = await kvstore_impl(self.config.persistence)
# Get connection string from config (either direct or built from parameters)
connection_string = self.config.get_connection_string()
# Skip MongoDB connection if no connection parameters provided
# This allows other providers to work without MongoDB credentials
if not connection_string:
logger.warning(
"MongoDB connection parameters not provided. "
"MongoDB vector store will not be available until credentials are configured."
)
return
# Connect to MongoDB with optimized settings for RAG
self.client = MongoClient(
connection_string,
server_api=ServerApi("1"),
maxPoolSize=self.config.max_pool_size,
serverSelectionTimeoutMS=self.config.timeout_ms,
# Additional settings for RAG performance
retryWrites=True,
readPreference="primaryPreferred",
)
# Test connection
try:
self.client.admin.command("ping")
logger.info("Successfully connected to MongoDB Atlas for RAG")
except Exception as conn_error:
# Extract just the basic error type without the full traceback
error_type = type(conn_error).__name__
logger.warning(
f"MongoDB connection failed ({error_type}). "
"MongoDB vector store will not be available. "
f"Attempted to connect to: {self.config.host or 'connection_string'}:{self.config.port or '(from connection_string)'}"
)
# Close the client and clear it
if self.client:
self.client.close()
self.client = None
return
# Get database
self.database = self.client[self.config.database_name]
# Initialize OpenAI vector stores
await self.initialize_openai_vector_stores()
# Load existing vector databases
await self._load_existing_vector_dbs()
logger.info("MongoDB Atlas Vector IO adapter for RAG initialized successfully")
except Exception as e:
logger.exception("Failed to initialize MongoDB Atlas Vector IO adapter for RAG")
# Close the client if it was created
if self.client:
self.client.close()
self.client = None
# Log warning instead of raising to allow tests to skip gracefully
logger.warning(f"MongoDB initialization failed: {e}. MongoDB vector store will not be available.")
async def shutdown(self) -> None:
"""Shutdown MongoDB connection."""
if self.client:
self.client.close()
logger.info("MongoDB Atlas RAG connection closed")
async def health(self) -> HealthResponse:
"""Perform health check on MongoDB connection."""
try:
if self.client:
self.client.admin.command("ping")
return HealthResponse(status=HealthStatus.OK)
else:
return HealthResponse(status=HealthStatus.ERROR, message="MongoDB client not initialized")
except Exception as e:
return HealthResponse(
status=HealthStatus.ERROR,
message=f"MongoDB RAG health check failed: {str(e)}",
)
async def register_vector_store(self, vector_store: VectorStore) -> None:
"""Register a new vector store optimized for RAG."""
if self.database is None:
raise RuntimeError("MongoDB database not initialized")
# Create collection name from vector store identifier
collection_name = sanitize_collection_name(vector_store.identifier)
collection = self.database[collection_name]
# Create and initialize MongoDB index optimized for RAG
mongodb_index = MongoDBIndex(vector_store, collection, self.config)
await mongodb_index.initialize()
# Create vector store with index wrapper
vector_store_with_index = VectorStoreWithIndex(
vector_store=vector_store,
index=mongodb_index,
inference_api=self.inference_api,
)
# Cache the vector store
self.cache[vector_store.identifier] = vector_store_with_index
# Save vector store info to KVStore for persistence
if self.kvstore:
await self.kvstore.set(
f"{VECTOR_DBS_PREFIX}{vector_store.identifier}",
vector_store.model_dump_json(),
)
logger.info(f"Registered vector store for RAG: {vector_store.identifier}")
async def unregister_vector_store(self, vector_store_id: str) -> None:
"""Unregister a vector store."""
if vector_store_id in self.cache:
await self.cache[vector_store_id].index.delete()
del self.cache[vector_store_id]
# Clean up from KV store
if self.kvstore:
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
logger.info(f"Unregistered vector store: {vector_store_id}")
async def insert_chunks(
self,
vector_store_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
"""Insert chunks into the vector database optimized for RAG."""
vector_db_with_index = await self._get_vector_db_index(vector_store_id)
await vector_db_with_index.insert_chunks(chunks)
async def query_chunks(
self,
vector_store_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
"""Query chunks from the vector database optimized for RAG context retrieval."""
vector_db_with_index = await self._get_vector_db_index(vector_store_id)
return await vector_db_with_index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete chunks from the vector database."""
vector_db_with_index = await self._get_vector_db_index(store_id)
await vector_db_with_index.index.delete_chunks(chunks_for_deletion)
async def _get_vector_db_index(self, vector_db_id: str) -> VectorStoreWithIndex:
"""Get vector store index from cache."""
if vector_db_id in self.cache:
return self.cache[vector_db_id]
raise VectorStoreNotFoundError(vector_db_id)
async def _load_existing_vector_dbs(self) -> None:
"""Load existing vector databases from KVStore."""
if not self.kvstore:
return
try:
# Use keys_in_range to get all vector database keys from KVStore
# This searches for keys with the prefix by using range scan
start_key = VECTOR_DBS_PREFIX
# Create an end key by incrementing the last character
end_key = VECTOR_DBS_PREFIX[:-1] + chr(ord(VECTOR_DBS_PREFIX[-1]) + 1)
vector_db_keys = await self.kvstore.keys_in_range(start_key, end_key)
for key in vector_db_keys:
try:
vector_store_data = await self.kvstore.get(key)
if vector_store_data:
import json
vector_store = VectorStore(**json.loads(vector_store_data))
# Register the vector store without re-initializing
await self._register_existing_vector_store(vector_store)
logger.info(f"Loaded existing RAG-optimized vector store: {vector_store.identifier}")
except Exception as e:
logger.warning(f"Failed to load vector store from key {key}: {e}")
continue
except Exception as e:
logger.warning(f"Failed to load existing vector stores: {e}")
async def _register_existing_vector_store(self, vector_store: VectorStore) -> None:
"""Register an existing vector store without re-initialization."""
if self.database is None:
raise RuntimeError("MongoDB database not initialized")
# Create collection name from vector store identifier
collection_name = sanitize_collection_name(vector_store.identifier)
collection = self.database[collection_name]
# Create MongoDB index without initialization (collection already exists)
mongodb_index = MongoDBIndex(vector_store, collection, self.config)
# Create vector store with index wrapper
vector_store_with_index = VectorStoreWithIndex(
vector_store=vector_store,
index=mongodb_index,
inference_api=self.inference_api,
)
# Cache the vector store
self.cache[vector_store.identifier] = vector_store_with_index

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.