Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)

Merge 503ad16002 into 356f37b1ba (commit fdc9ba2687)

17 changed files with 2066 additions and 3 deletions
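
This change adds a new `remote::mongodb` vector-io provider backed by MongoDB Atlas Vector Search: CI coverage in the vector-io integration test matrix, provider documentation, the provider implementation (`__init__.py`, `config.py`, `mongodb.py`), and registration in the distribution specs and the starter template.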

Integration vector-io tests workflow:

```diff
@@ -31,7 +31,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"]
+        vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant", "remote::mongodb"]
         python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
       fail-fast: false # we want to run all tests regardless of failure
@@ -101,6 +101,16 @@ jobs:
             -p 6333:6333 \
             qdrant/qdrant

+      - name: Setup MongoDB
+        if: matrix.vector-io-provider == 'remote::mongodb'
+        run: |
+          docker run --rm -d --pull always \
+            --name mongodb \
+            -p 27017:27017 \
+            -e MONGO_INITDB_ROOT_USERNAME=llamastack \
+            -e MONGO_INITDB_ROOT_PASSWORD=llamastack \
+            mongodb/mongodb-atlas-local:latest
+
       - name: Wait for Qdrant to be ready
         if: matrix.vector-io-provider == 'remote::qdrant'
         run: |
@@ -116,6 +126,21 @@ jobs:
           docker logs qdrant
           exit 1

+      - name: Wait for MongoDB to be ready
+        if: matrix.vector-io-provider == 'remote::mongodb'
+        run: |
+          echo "Waiting for MongoDB to be ready..."
+          for i in {1..30}; do
+            if docker exec mongodb mongosh --quiet --eval "db.adminCommand('ping').ok" > /dev/null 2>&1; then
+              echo "MongoDB is ready!"
+              exit 0
+            fi
+            sleep 2
+          done
+          echo "MongoDB failed to start"
+          docker logs mongodb
+          exit 1
+
       - name: Wait for ChromaDB to be ready
         if: matrix.vector-io-provider == 'remote::chromadb'
         run: |
@@ -170,6 +195,11 @@ jobs:
           QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
           ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
           WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
+          ENABLE_MONGODB: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'true' || '' }}
+          MONGODB_HOST: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'localhost' || '' }}
+          MONGODB_PORT: ${{ matrix.vector-io-provider == 'remote::mongodb' && '27017' || '' }}
+          MONGODB_USERNAME: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'llamastack' || '' }}
+          MONGODB_PASSWORD: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'llamastack' || '' }}
         run: |
           uv run --no-sync \
             pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
@@ -196,6 +226,11 @@ jobs:
         run: |
           docker logs qdrant > qdrant.log

+      - name: Write MongoDB logs to file
+        if: ${{ always() && matrix.vector-io-provider == 'remote::mongodb' }}
+        run: |
+          docker logs mongodb > mongodb.log
+
       - name: Upload all logs to artifacts
         if: ${{ always() }}
         uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
```
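
To debug a failing `remote::mongodb` matrix job locally, the same service can be stood up with the commands from the steps above (a direct restatement of the workflow's own commands, not additional tooling):

```bash
# Start the same container the CI job uses
docker run --rm -d --pull always \
  --name mongodb \
  -p 27017:27017 \
  -e MONGO_INITDB_ROOT_USERNAME=llamastack \
  -e MONGO_INITDB_ROOT_PASSWORD=llamastack \
  mongodb/mongodb-atlas-local:latest

# Same readiness probe the workflow polls in a loop
docker exec mongodb mongosh --quiet --eval "db.adminCommand('ping').ok"
```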

docs/docs/providers/vector_io/remote_mongodb.mdx (new file, 276 lines)

---
description: |
  [MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It
  uses MongoDB Atlas Vector Search to store and query vectors in the cloud.
  That means you get enterprise-grade vector search with MongoDB's scalability and reliability.

  ## Features

  - Cloud-native vector search with MongoDB Atlas
  - Fully integrated with Llama Stack
  - Enterprise-grade security and scalability
  - Supports multiple search modes: vector, keyword, and hybrid search
  - Built-in metadata filtering and text search capabilities
  - Automatic index management

  ## Search Modes

  MongoDB Atlas Vector Search supports three different search modes:

  ### Vector Search
  Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search using embedding vectors.

  ```python
  # Vector search example
  search_response = client.vector_stores.search(
      vector_store_id=vector_store.id,
      query="What is machine learning?",
      search_mode="vector",
      max_num_results=5,
  )
  ```

  ### Keyword Search
  Keyword search uses MongoDB's text search capabilities with full-text indexes to find chunks containing specific terms.

  ```python
  # Keyword search example
  search_response = client.vector_stores.search(
      vector_store_id=vector_store.id,
      query="Python programming language",
      search_mode="keyword",
      max_num_results=5,
  )
  ```

  ### Hybrid Search
  Hybrid search combines both vector and keyword search methods using configurable reranking algorithms.

  ```python
  # Hybrid search with RRF ranker (default)
  search_response = client.vector_stores.search(
      vector_store_id=vector_store.id,
      query="neural networks in Python",
      search_mode="hybrid",
      max_num_results=5,
  )

  # Hybrid search with weighted ranker
  search_response = client.vector_stores.search(
      vector_store_id=vector_store.id,
      query="neural networks in Python",
      search_mode="hybrid",
      max_num_results=5,
      ranking_options={
          "ranker": {
              "type": "weighted",
              "alpha": 0.7,  # 70% vector search, 30% keyword search
          }
      },
  )
  ```

  ## Usage

  To use MongoDB Atlas in your Llama Stack project, follow these steps:

  1. Create a MongoDB Atlas cluster with Vector Search enabled
  2. Install the necessary dependencies
  3. Configure your Llama Stack project to use MongoDB
  4. Start storing and querying vectors

  ## Configuration

  ### Environment Variables
  Set up the following environment variable for your MongoDB Atlas connection:

  ```bash
  export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack"
  ```

  ### Configuration Example

  ```yaml
  vector_io:
    - provider_id: mongodb_atlas
      provider_type: remote::mongodb
      config:
        connection_string: "${env.MONGODB_CONNECTION_STRING}"
        database_name: "llama_stack"
        index_name: "vector_index"
        similarity_metric: "cosine"
  ```

  ## Installation

  You can install the MongoDB Python driver using pip:

  ```bash
  pip install pymongo
  ```

  ## Documentation

  See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details about MongoDB Atlas Vector Search.

  For general MongoDB documentation, visit [MongoDB Documentation](https://docs.mongodb.com/).
sidebar_label: Remote - Mongodb
title: remote::mongodb
---

# remote::mongodb

## Description

[MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It
uses MongoDB Atlas Vector Search to store and query vectors in the cloud.
That means you get enterprise-grade vector search with MongoDB's scalability and reliability.

## Features

- Cloud-native vector search with MongoDB Atlas
- Fully integrated with Llama Stack
- Enterprise-grade security and scalability
- Supports multiple search modes: vector, keyword, and hybrid search
- Built-in metadata filtering and text search capabilities
- Automatic index management

## Search Modes

MongoDB Atlas Vector Search supports three different search modes:

### Vector Search
Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search using embedding vectors.

```python
# Vector search example
search_response = client.vector_stores.search(
    vector_store_id=vector_store.id,
    query="What is machine learning?",
    search_mode="vector",
    max_num_results=5,
)
```

### Keyword Search
Keyword search uses MongoDB's text search capabilities with full-text indexes to find chunks containing specific terms.

```python
# Keyword search example
search_response = client.vector_stores.search(
    vector_store_id=vector_store.id,
    query="Python programming language",
    search_mode="keyword",
    max_num_results=5,
)
```

### Hybrid Search
Hybrid search combines both vector and keyword search methods using configurable reranking algorithms.

```python
# Hybrid search with RRF ranker (default)
search_response = client.vector_stores.search(
    vector_store_id=vector_store.id,
    query="neural networks in Python",
    search_mode="hybrid",
    max_num_results=5,
)

# Hybrid search with weighted ranker
search_response = client.vector_stores.search(
    vector_store_id=vector_store.id,
    query="neural networks in Python",
    search_mode="hybrid",
    max_num_results=5,
    ranking_options={
        "ranker": {
            "type": "weighted",
            "alpha": 0.7,  # 70% vector search, 30% keyword search
        }
    },
)
```

## Usage

To use MongoDB Atlas in your Llama Stack project, follow these steps:

1. Create a MongoDB Atlas cluster with Vector Search enabled
2. Install the necessary dependencies
3. Configure your Llama Stack project to use MongoDB
4. Start storing and querying vectors

## Configuration

### Environment Variables
Set up the following environment variable for your MongoDB Atlas connection:

```bash
export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack"
```

### Configuration Example

```yaml
vector_io:
  - provider_id: mongodb_atlas
    provider_type: remote::mongodb
    config:
      connection_string: "${env.MONGODB_CONNECTION_STRING}"
      database_name: "llama_stack"
      index_name: "vector_index"
      similarity_metric: "cosine"
```

## Installation

You can install the MongoDB Python driver using pip:

```bash
pip install pymongo
```

## Documentation

See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details about MongoDB Atlas Vector Search.

For general MongoDB documentation, visit [MongoDB Documentation](https://docs.mongodb.com/).

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `connection_string` | `str \| None` | No | | MongoDB connection string (e.g., mongodb://user:pass@localhost:27017/ or mongodb+srv://user:pass@cluster.mongodb.net/) |
| `host` | `str \| None` | No | | MongoDB host (used if connection_string is not provided) |
| `port` | `int \| None` | No | | MongoDB port (used if connection_string is not provided) |
| `username` | `str \| None` | No | | MongoDB username (used if connection_string is not provided) |
| `password` | `str \| None` | No | | MongoDB password (used if connection_string is not provided) |
| `database_name` | `<class 'str'>` | No | llama_stack | Database name to use for vector collections |
| `index_name` | `<class 'str'>` | No | vector_index | Name of the vector search index |
| `path_field` | `<class 'str'>` | No | embedding | Field name for storing embeddings |
| `similarity_metric` | `<class 'str'>` | No | cosine | Similarity metric: cosine, euclidean, or dotProduct |
| `max_pool_size` | `<class 'int'>` | No | 100 | Maximum connection pool size |
| `timeout_ms` | `<class 'int'>` | No | 30000 | Connection timeout in milliseconds |
| `persistence` | `llama_stack.core.storage.datatypes.KVStoreReference \| None` | No | | Config for KV store backend for metadata storage |

## Sample Configuration

```yaml
connection_string: ${env.MONGODB_CONNECTION_STRING:=}
host: ${env.MONGODB_HOST:=localhost}
port: ${env.MONGODB_PORT:=27017}
username: ${env.MONGODB_USERNAME:=}
password: ${env.MONGODB_PASSWORD:=}
database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack}
index_name: ${env.MONGODB_INDEX_NAME:=vector_index}
path_field: ${env.MONGODB_PATH_FIELD:=embedding}
similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine}
max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100}
timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000}
persistence:
  namespace: vector_io::mongodb_atlas
  backend: kv_default
```
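
The doc above names two hybrid rankers but does not spell out the math; the actual combination lives in Llama Stack's shared `WeightedInMemoryAggregator`, used by the provider code later in this diff. A minimal standalone sketch of both schemes, assuming per-chunk-id score dictionaries (the RRF constant 60 is the conventional default and an assumption here):

```python
def weighted_combine(
    vector_scores: dict[str, float],
    keyword_scores: dict[str, float],
    alpha: float = 0.7,
) -> dict[str, float]:
    """Linear blend: alpha weights the vector score, (1 - alpha) the keyword score."""
    ids = set(vector_scores) | set(keyword_scores)
    return {
        i: alpha * vector_scores.get(i, 0.0) + (1.0 - alpha) * keyword_scores.get(i, 0.0)
        for i in ids
    }


def rrf_combine(
    vector_scores: dict[str, float],
    keyword_scores: dict[str, float],
    k: int = 60,
) -> dict[str, float]:
    """Reciprocal Rank Fusion: combines by rank, so raw score scales don't matter."""
    combined: dict[str, float] = {}
    for scores in (vector_scores, keyword_scores):
        ranked = sorted(scores, key=scores.__getitem__, reverse=True)
        for rank, chunk_id in enumerate(ranked, start=1):
            combined[chunk_id] = combined.get(chunk_id, 0.0) + 1.0 / (k + rank)
    return combined
```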

llama_stack/providers/remote/vector_io/mongodb/__init__.py (new file, 20 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import MongoDBVectorIOConfig


async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: dict[Api, ProviderSpec]):
    from .mongodb import MongoDBVectorIOAdapter

    # Handle the deps resolution - if files API exists, pass it, otherwise None
    files_api = deps.get(Api.files)
    models_api = deps.get(Api.models)
    impl = MongoDBVectorIOAdapter(config, deps[Api.inference], files_api, models_api)
    await impl.initialize()
    return impl
```

llama_stack/providers/remote/vector_io/mongodb/config.py (new file, 102 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class MongoDBVectorIOConfig(BaseModel):
    """Configuration for MongoDB Atlas Vector Search provider.

    This provider connects to MongoDB Atlas and uses Vector Search for RAG operations.
    """

    # MongoDB connection details - either connection_string or individual parameters
    connection_string: str | None = Field(
        default=None,
        description="MongoDB connection string (e.g., mongodb://user:pass@localhost:27017/ or mongodb+srv://user:pass@cluster.mongodb.net/)",
    )
    host: str | None = Field(default=None, description="MongoDB host (used if connection_string is not provided)")
    port: int | None = Field(default=None, description="MongoDB port (used if connection_string is not provided)")
    username: str | None = Field(
        default=None, description="MongoDB username (used if connection_string is not provided)"
    )
    password: str | None = Field(
        default=None, description="MongoDB password (used if connection_string is not provided)"
    )
    database_name: str = Field(default="llama_stack", description="Database name to use for vector collections")

    # Vector search configuration
    index_name: str = Field(default="vector_index", description="Name of the vector search index")
    path_field: str = Field(default="embedding", description="Field name for storing embeddings")
    similarity_metric: str = Field(
        default="cosine",
        description="Similarity metric: cosine, euclidean, or dotProduct",
    )

    # Connection options
    max_pool_size: int = Field(default=100, description="Maximum connection pool size")
    timeout_ms: int = Field(default=30000, description="Connection timeout in milliseconds")

    # KV store configuration
    persistence: KVStoreReference | None = Field(
        description="Config for KV store backend for metadata storage", default=None
    )

    def get_connection_string(self) -> str | None:
        """Build connection string from individual parameters if not provided directly.

        If both connection_string and individual parameters (host/port) are provided,
        individual parameters take precedence to allow test environment overrides.
        """
        # Prioritize individual connection parameters over connection_string
        # This allows test environments to override with MONGODB_HOST/PORT/etc
        if self.host and self.port:
            auth_part = ""
            if self.username and self.password:
                auth_part = f"{self.username}:{self.password}@"
            return f"mongodb://{auth_part}{self.host}:{self.port}/"

        # Fall back to connection_string if provided
        if self.connection_string:
            return self.connection_string

        return None

    @classmethod
    def sample_run_config(
        cls,
        __distro_dir__: str,
        connection_string: str = "${env.MONGODB_CONNECTION_STRING:=}",
        host: str = "${env.MONGODB_HOST:=localhost}",
        port: int = "${env.MONGODB_PORT:=27017}",
        username: str = "${env.MONGODB_USERNAME:=}",
        password: str = "${env.MONGODB_PASSWORD:=}",
        database_name: str = "${env.MONGODB_DATABASE_NAME:=llama_stack}",
        **kwargs: Any,
    ) -> dict[str, Any]:
        return {
            "connection_string": connection_string,
            "host": host,
            "port": port,
            "username": username,
            "password": password,
            "database_name": database_name,
            "index_name": "${env.MONGODB_INDEX_NAME:=vector_index}",
            "path_field": "${env.MONGODB_PATH_FIELD:=embedding}",
            "similarity_metric": "${env.MONGODB_SIMILARITY_METRIC:=cosine}",
            "max_pool_size": "${env.MONGODB_MAX_POOL_SIZE:=100}",
            "timeout_ms": "${env.MONGODB_TIMEOUT_MS:=30000}",
            "persistence": KVStoreReference(
                backend="kv_default",
                namespace="vector_io::mongodb_atlas",
            ).model_dump(exclude_none=True),
        }
```
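
One behavior worth seeing concretely: `get_connection_string()` gives `host`/`port` precedence over an explicit `connection_string`, so environment-style overrides (as in the CI job above) win. A quick sketch of that precedence using the class as defined:

```python
from llama_stack.providers.remote.vector_io.mongodb.config import MongoDBVectorIOConfig

cfg = MongoDBVectorIOConfig(
    connection_string="mongodb+srv://user:pass@cluster.mongodb.net/",
    host="localhost",
    port=27017,
)
# host/port win over connection_string by design, so a test environment
# can redirect an Atlas-style config at a local container:
assert cfg.get_connection_string() == "mongodb://localhost:27017/"

# Without host/port, the configured URI is used as-is:
cfg_atlas = MongoDBVectorIOConfig(connection_string="mongodb+srv://user:pass@cluster.mongodb.net/")
assert cfg_atlas.get_connection_string() == "mongodb+srv://user:pass@cluster.mongodb.net/"
```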

llama_stack/providers/remote/vector_io/mongodb/mongodb.py (new file, 609 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import heapq
import time
from typing import Any

from numpy.typing import NDArray
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.operations import SearchIndexModel
from pymongo.server_api import ServerApi

from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    VectorStoresProtocolPrivate,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
    OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.vector_store import (
    ChunkForDeletion,
    EmbeddingIndex,
    VectorStoreWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import (
    WeightedInMemoryAggregator,
    sanitize_collection_name,
)

from .config import MongoDBVectorIOConfig

logger = get_logger(name=__name__, category="vector_io::mongodb")

VERSION = "v1"
VECTOR_DBS_PREFIX = f"vector_dbs:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:mongodb:{VERSION}::"


class MongoDBIndex(EmbeddingIndex):
    """MongoDB Atlas Vector Search index implementation optimized for RAG."""

    def __init__(
        self,
        vector_store: VectorStore,
        collection: Collection,
        config: MongoDBVectorIOConfig,
    ):
        self.vector_store = vector_store
        self.collection = collection
        self.config = config
        self.dimension = vector_store.embedding_dimension

    async def initialize(self) -> None:
        """Initialize the MongoDB collection and ensure vector search index exists."""
        try:
            # Create the collection if it doesn't exist
            collection_names = self.collection.database.list_collection_names()
            if self.collection.name not in collection_names:
                logger.info(f"Creating collection '{self.collection.name}'")
                # Create collection by inserting a dummy document
                dummy_doc = {"_id": "__dummy__", "dummy": True}
                self.collection.insert_one(dummy_doc)
                # Remove the dummy document
                self.collection.delete_one({"_id": "__dummy__"})
                logger.info(f"Collection '{self.collection.name}' created successfully")

            # Create optimized vector search index for RAG
            await self._create_vector_search_index()

            # Create text index for hybrid search
            await self._ensure_text_index()

        except Exception as e:
            logger.exception(
                f"Failed to initialize MongoDB index for vector_store: {self.vector_store.identifier}. "
                f"Collection name: {self.collection.name}. Error: {str(e)}"
            )
            # Don't fail completely - just log the error and continue
            logger.warning(
                "Continuing without complete index initialization. "
                "You may need to create indexes manually in MongoDB Atlas dashboard."
            )

    async def _create_vector_search_index(self) -> None:
        """Create optimized vector search index based on MongoDB RAG best practices."""
        try:
            # Check if vector search index exists
            indexes = list(self.collection.list_search_indexes())
            index_exists = any(idx.get("name") == self.config.index_name for idx in indexes)

            if not index_exists:
                # Create vector search index optimized for RAG
                # Based on MongoDB's RAG example using new vectorSearch format
                search_index_model = SearchIndexModel(
                    definition={
                        "fields": [
                            {
                                "type": "vector",
                                "numDimensions": self.dimension,
                                "path": self.config.path_field,
                                "similarity": self._convert_similarity_metric(self.config.similarity_metric),
                            }
                        ]
                    },
                    name=self.config.index_name,
                    type="vectorSearch",
                )

                logger.info(
                    f"Creating vector search index '{self.config.index_name}' for RAG on collection '{self.collection.name}'"
                )

                self.collection.create_search_index(model=search_index_model)

                # Wait for index to be ready (like in MongoDB RAG example)
                await self._wait_for_index_ready()

                logger.info("Vector search index created and ready for RAG queries")

        except Exception as e:
            logger.warning(f"Failed to create vector search index: {e}")

    def _convert_similarity_metric(self, metric: str) -> str:
        """Convert internal similarity metric to MongoDB Atlas format."""
        metric_map = {
            "cosine": "cosine",
            "euclidean": "euclidean",
            "dotProduct": "dotProduct",
            "dot_product": "dotProduct",
        }
        return metric_map.get(metric, "cosine")

    async def _wait_for_index_ready(self) -> None:
        """Wait for the vector search index to be ready, based on MongoDB RAG example."""
        logger.info("Waiting for vector search index to be ready...")

        max_wait_time = 300  # 5 minutes max wait
        wait_interval = 5
        elapsed_time = 0

        while elapsed_time < max_wait_time:
            try:
                indices = list(self.collection.list_search_indexes(self.config.index_name))
                if len(indices) and indices[0].get("queryable") is True:
                    logger.info(f"Vector search index '{self.config.index_name}' is ready for querying")
                    return
            except Exception:
                pass

            time.sleep(wait_interval)
            elapsed_time += wait_interval

        logger.warning(f"Vector search index may not be fully ready after {max_wait_time}s")

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray) -> None:
        """Add chunks with embeddings to MongoDB collection optimized for RAG."""
        if len(chunks) != len(embeddings):
            raise ValueError(f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}")

        documents = []
        for i, chunk in enumerate(chunks):
            # Structure document for optimal RAG retrieval
            doc = {
                "_id": chunk.chunk_id,
                "chunk_id": chunk.chunk_id,
                "text": interleaved_content_as_str(chunk.content),  # Key field for RAG context
                "content": interleaved_content_as_str(chunk.content),  # Backward compatibility
                "metadata": chunk.metadata or {},
                "chunk_metadata": (chunk.chunk_metadata.model_dump() if chunk.chunk_metadata else {}),
                self.config.path_field: embeddings[i].tolist(),  # Vector embeddings
                "document": chunk.model_dump(),  # Full chunk data
            }
            documents.append(doc)

        try:
            # Use upsert behavior for chunks
            for doc in documents:
                self.collection.replace_one({"_id": doc["_id"]}, doc, upsert=True)

            logger.debug(f"Successfully added {len(chunks)} chunks optimized for RAG to MongoDB collection")
        except Exception as e:
            logger.exception(f"Failed to add chunks to MongoDB collection: {e}")
            raise

    async def query_vector(
        self,
        embedding: NDArray,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        """Perform vector similarity search optimized for RAG based on MongoDB example."""
        try:
            # Use MongoDB's vector search aggregation pipeline optimized for RAG
            pipeline = [
                {
                    "$vectorSearch": {
                        "index": self.config.index_name,
                        "queryVector": embedding.tolist(),
                        "path": self.config.path_field,
                        "numCandidates": min(k * 10, 1000),  # Cap at 1000 to prevent excessive candidates
                        "limit": k,
                    }
                },
                {
                    "$project": {
                        "_id": 0,
                        "text": 1,  # Primary field for RAG context
                        "content": 1,  # Backward compatibility
                        "metadata": 1,
                        "chunk_metadata": 1,
                        "document": 1,
                        "score": {"$meta": "vectorSearchScore"},
                    }
                },
                {"$match": {"score": {"$gte": score_threshold}}},
            ]

            results = list(self.collection.aggregate(pipeline))

            chunks = []
            scores = []
            for result in results:
                score = result.get("score", 0.0)
                if score >= score_threshold:
                    chunk_data = result.get("document", {})
                    if chunk_data:
                        chunks.append(Chunk(**chunk_data))
                        scores.append(float(score))

            logger.debug(f"Vector search for RAG returned {len(chunks)} results")
            return QueryChunksResponse(chunks=chunks, scores=scores)

        except Exception as e:
            logger.exception(f"Vector search for RAG failed: {e}")
            raise RuntimeError(f"Vector search for RAG failed: {e}") from e

    async def query_keyword(
        self,
        query_string: str,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        """Perform text search using MongoDB's text search for RAG context retrieval."""
        try:
            # Ensure text index exists
            await self._ensure_text_index()

            pipeline: list[dict[str, Any]] = [
                {"$match": {"$text": {"$search": query_string}}},
                {
                    "$project": {
                        "_id": 0,
                        "text": 1,  # Primary field for RAG context
                        "content": 1,  # Backward compatibility
                        "metadata": 1,
                        "chunk_metadata": 1,
                        "document": 1,
                        "score": {"$meta": "textScore"},
                    }
                },
                {"$match": {"score": {"$gte": score_threshold}}},
                {"$sort": {"score": {"$meta": "textScore"}}},
                {"$limit": k},
            ]

            results = list(self.collection.aggregate(pipeline))

            chunks = []
            scores = []
            for result in results:
                score = result.get("score", 0.0)
                if score >= score_threshold:
                    chunk_data = result.get("document", {})
                    if chunk_data:
                        chunks.append(Chunk(**chunk_data))
                        scores.append(float(score))

            logger.debug(f"Keyword search for RAG returned {len(chunks)} results")
            return QueryChunksResponse(chunks=chunks, scores=scores)

        except Exception as e:
            logger.exception(f"Keyword search for RAG failed: {e}")
            raise RuntimeError(f"Keyword search for RAG failed: {e}") from e

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Perform hybrid search for enhanced RAG context retrieval."""
        if reranker_params is None:
            reranker_params = {}

        # Get results from both search methods
        vector_response = await self.query_vector(embedding, k, 0.0)
        keyword_response = await self.query_keyword(query_string, k, 0.0)

        # Convert responses to score dictionaries
        vector_scores = {
            chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
        }
        keyword_scores = {
            chunk.chunk_id: score
            for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
        }

        # Combine scores using the reranking utility
        combined_scores = WeightedInMemoryAggregator.combine_search_results(
            vector_scores, keyword_scores, reranker_type, reranker_params
        )

        # Get top-k results
        top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])

        # Filter by score threshold
        filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]

        # Create chunk map
        chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}

        # Build final results
        chunks = []
        scores = []
        for doc_id, score in filtered_items:
            if doc_id in chunk_map:
                chunks.append(chunk_map[doc_id])
                scores.append(score)

        logger.debug(f"Hybrid search for RAG returned {len(chunks)} results")
        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Delete chunks from MongoDB collection."""
        chunk_ids = [c.chunk_id for c in chunks_for_deletion]
        try:
            result = self.collection.delete_many({"_id": {"$in": chunk_ids}})
            logger.debug(f"Deleted {result.deleted_count} chunks from MongoDB collection")
        except Exception as e:
            logger.exception(f"Failed to delete chunks: {e}")
            raise

    async def delete(self) -> None:
        """Delete the entire collection."""
        try:
            self.collection.drop()
            logger.debug(f"Dropped MongoDB collection: {self.collection.name}")
        except Exception as e:
            logger.exception(f"Failed to drop collection: {e}")
            raise

    async def _ensure_text_index(self) -> None:
        """Ensure text search index exists on content fields for RAG."""
        try:
            indexes = list(self.collection.list_indexes())
            text_index_exists = any(
                any(key.startswith(("content", "text")) for key in idx.get("key", {}).keys())
                and idx.get("textIndexVersion") is not None
                for idx in indexes
            )

            if not text_index_exists:
                logger.info("Creating text search index on content fields for RAG")
                # Index both 'text' and 'content' fields for comprehensive text search
                self.collection.create_index([("text", "text"), ("content", "text")])
                logger.info("Text search index created successfully for RAG")

        except Exception as e:
            logger.warning(f"Failed to create text index for RAG: {e}")


class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
    """MongoDB Atlas Vector Search adapter for Llama Stack optimized for RAG workflows."""

    def __init__(
        self,
        config: MongoDBVectorIOConfig,
        inference_api,
        files_api=None,
        models_api=None,
    ) -> None:
        # Handle the case where files_api might be a ProviderSpec that needs resolution
        resolved_files_api = files_api
        super().__init__(files_api=resolved_files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
        self.models_api = models_api
        self.client: MongoClient | None = None
        self.database: Database | None = None
        self.cache: dict[str, VectorStoreWithIndex] = {}
        self.kvstore: KVStore | None = None

    async def initialize(self) -> None:
        """Initialize MongoDB connection optimized for RAG workflows."""
        logger.info("Initializing MongoDB Atlas Vector IO adapter for RAG")

        try:
            # Initialize KV store for metadata
            if self.config.persistence:
                self.kvstore = await kvstore_impl(self.config.persistence)

            # Skip MongoDB connection if no connection string provided
            # This allows other providers to work without MongoDB credentials
            if not self.config.connection_string:
                logger.warning(
                    "MongoDB connection_string not provided. "
                    "MongoDB vector store will not be available until credentials are configured."
                )
                return

            # Connect to MongoDB with optimized settings for RAG
            self.client = MongoClient(
                self.config.connection_string,
                server_api=ServerApi("1"),
                maxPoolSize=self.config.max_pool_size,
                serverSelectionTimeoutMS=self.config.timeout_ms,
                # Additional settings for RAG performance
                retryWrites=True,
                readPreference="primaryPreferred",
            )

            # Test connection
            self.client.admin.command("ping")
            logger.info("Successfully connected to MongoDB Atlas for RAG")

            # Get database
            self.database = self.client[self.config.database_name]

            # Initialize OpenAI vector stores
            await self.initialize_openai_vector_stores()

            # Load existing vector databases
            await self._load_existing_vector_dbs()

            logger.info("MongoDB Atlas Vector IO adapter for RAG initialized successfully")

        except Exception as e:
            logger.exception("Failed to initialize MongoDB Atlas Vector IO adapter for RAG")
            raise RuntimeError("Failed to initialize MongoDB Atlas Vector IO adapter for RAG") from e

    async def shutdown(self) -> None:
        """Shutdown MongoDB connection."""
        if self.client:
            self.client.close()
            logger.info("MongoDB Atlas RAG connection closed")

    async def health(self) -> HealthResponse:
        """Perform health check on MongoDB connection."""
        try:
            if self.client:
                self.client.admin.command("ping")
                return HealthResponse(status=HealthStatus.OK)
            else:
                return HealthResponse(status=HealthStatus.ERROR, message="MongoDB client not initialized")
        except Exception as e:
            return HealthResponse(
                status=HealthStatus.ERROR,
                message=f"MongoDB RAG health check failed: {str(e)}",
            )

    async def register_vector_store(self, vector_store: VectorStore) -> None:
        """Register a new vector store optimized for RAG."""
        if self.database is None:
            raise RuntimeError("MongoDB database not initialized")

        # Create collection name from vector store identifier
        collection_name = sanitize_collection_name(vector_store.identifier)
        collection = self.database[collection_name]

        # Create and initialize MongoDB index optimized for RAG
        mongodb_index = MongoDBIndex(vector_store, collection, self.config)
        await mongodb_index.initialize()

        # Create vector store with index wrapper
        vector_store_with_index = VectorStoreWithIndex(
            vector_store=vector_store,
            index=mongodb_index,
            inference_api=self.inference_api,
        )

        # Cache the vector store
        self.cache[vector_store.identifier] = vector_store_with_index

        # Save vector store info to KVStore for persistence
        if self.kvstore:
            await self.kvstore.set(
                f"{VECTOR_DBS_PREFIX}{vector_store.identifier}",
                vector_store.model_dump_json(),
            )

        logger.info(f"Registered vector store for RAG: {vector_store.identifier}")

    async def unregister_vector_store(self, vector_store_id: str) -> None:
        """Unregister a vector store."""
        if vector_store_id in self.cache:
            await self.cache[vector_store_id].index.delete()
            del self.cache[vector_store_id]

        # Clean up from KV store
        if self.kvstore:
            await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")

        logger.info(f"Unregistered vector store: {vector_store_id}")

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        """Insert chunks into the vector database optimized for RAG."""
        vector_db_with_index = await self._get_vector_db_index(vector_db_id)
        await vector_db_with_index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Query chunks from the vector database optimized for RAG context retrieval."""
        vector_db_with_index = await self._get_vector_db_index(vector_db_id)
        return await vector_db_with_index.query_chunks(query, params)

    async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Delete chunks from the vector database."""
        vector_db_with_index = await self._get_vector_db_index(store_id)
        await vector_db_with_index.index.delete_chunks(chunks_for_deletion)

    async def _get_vector_db_index(self, vector_db_id: str) -> VectorStoreWithIndex:
        """Get vector store index from cache."""
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        raise VectorStoreNotFoundError(vector_db_id)

    async def _load_existing_vector_dbs(self) -> None:
        """Load existing vector databases from KVStore."""
        if not self.kvstore:
            return

        try:
            # Use keys_in_range to get all vector database keys from KVStore
            # This searches for keys with the prefix by using range scan
            start_key = VECTOR_DBS_PREFIX
            # Create an end key by incrementing the last character
            end_key = VECTOR_DBS_PREFIX[:-1] + chr(ord(VECTOR_DBS_PREFIX[-1]) + 1)

            vector_db_keys = await self.kvstore.keys_in_range(start_key, end_key)

            for key in vector_db_keys:
                try:
                    vector_store_data = await self.kvstore.get(key)
                    if vector_store_data:
                        import json

                        vector_store = VectorStore(**json.loads(vector_store_data))
                        # Register the vector store without re-initializing
                        await self._register_existing_vector_store(vector_store)
                        logger.info(f"Loaded existing RAG-optimized vector store: {vector_store.identifier}")
                except Exception as e:
                    logger.warning(f"Failed to load vector store from key {key}: {e}")
                    continue

        except Exception as e:
            logger.warning(f"Failed to load existing vector stores: {e}")

    async def _register_existing_vector_store(self, vector_store: VectorStore) -> None:
        """Register an existing vector store without re-initialization."""
        if self.database is None:
            raise RuntimeError("MongoDB database not initialized")

        # Create collection name from vector store identifier
        collection_name = sanitize_collection_name(vector_store.identifier)
        collection = self.database[collection_name]

        # Create MongoDB index without initialization (collection already exists)
        mongodb_index = MongoDBIndex(vector_store, collection, self.config)

        # Create vector store with index wrapper
        vector_store_with_index = VectorStoreWithIndex(
            vector_store=vector_store,
            index=mongodb_index,
            inference_api=self.inference_api,
        )

        # Cache the vector store
        self.cache[vector_store.identifier] = vector_store_with_index
```
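
For readers new to Atlas Vector Search, the retrieval path above reduces to a single aggregation pipeline. A standalone sketch of an equivalent query with plain pymongo (hypothetical collection, database, and vector values; it assumes a queryable `vectorSearch` index named `vector_index` over an `embedding` field, as the adapter creates):

```python
from pymongo import MongoClient

# Hypothetical connection; mirrors the CI container credentials above.
client = MongoClient("mongodb://llamastack:llamastack@localhost:27017/")
coll = client["llama_stack"]["demo_chunks"]

k = 5
query_vector = [0.01] * 384  # one entry per embedding dimension (384 is an assumption)

pipeline = [
    {
        "$vectorSearch": {
            "index": "vector_index",
            "queryVector": query_vector,
            "path": "embedding",
            "numCandidates": min(k * 10, 1000),  # same cap the adapter applies
            "limit": k,
        }
    },
    {"$project": {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}},
]

for doc in coll.aggregate(pipeline):
    print(f"{doc['score']:.3f}  {doc.get('text', '')[:80]}")
```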
Distribution configs (the same change is applied to three distributions' build and run specs):

```diff
@@ -25,6 +25,7 @@ distribution_spec:
     - provider_type: inline::milvus
     - provider_type: remote::chromadb
     - provider_type: remote::pgvector
+    - provider_type: remote::mongodb
     - provider_type: remote::qdrant
     - provider_type: remote::weaviate
   files:
@@ -131,6 +131,23 @@ providers:
       persistence:
         namespace: vector_io::pgvector
         backend: kv_default
+  - provider_id: ${env.MONGODB_CONNECTION_STRING:+mongodb_atlas}
+    provider_type: remote::mongodb
+    config:
+      connection_string: ${env.MONGODB_CONNECTION_STRING:=}
+      host: ${env.MONGODB_HOST:=localhost}
+      port: ${env.MONGODB_PORT:=27017}
+      username: ${env.MONGODB_USERNAME:=}
+      password: ${env.MONGODB_PASSWORD:=}
+      database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack}
+      index_name: ${env.MONGODB_INDEX_NAME:=vector_index}
+      path_field: ${env.MONGODB_PATH_FIELD:=embedding}
+      similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine}
+      max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100}
+      timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000}
+      persistence:
+        namespace: vector_io::mongodb_atlas
+        backend: kv_default
   - provider_id: ${env.QDRANT_URL:+qdrant}
     provider_type: remote::qdrant
     config:
```

The same two hunks, offset by one line in the build spec (`@@ -26,6 +26,7 @@ distribution_spec:` and `@@ -131,6 +131,23 @@ providers:`), are applied verbatim in two further distribution configs.
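
A note on the `${env.VAR:+value}` / `${env.VAR:=default}` syntax used throughout these configs: it mirrors shell parameter expansion, where the `:=` form supplies a default when the variable is unset and the `:+` form yields the provider_id only when the variable is set. That is what keeps the MongoDB provider effectively disabled until `MONGODB_CONNECTION_STRING` is exported. The shell analogue, for intuition:

```bash
unset MONGODB_CONNECTION_STRING
echo "provider_id=${MONGODB_CONNECTION_STRING:+mongodb_atlas}"
# provider_id=              (unset -> empty -> provider disabled)

export MONGODB_CONNECTION_STRING="mongodb+srv://user:pass@cluster.mongodb.net/"
echo "provider_id=${MONGODB_CONNECTION_STRING:+mongodb_atlas}"
# provider_id=mongodb_atlas

echo "host=${MONGODB_HOST:=localhost}"
# host=localhost            (:= substitutes the default when unset)
```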
@ -36,11 +36,14 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
||||||
)
|
)
|
||||||
from llama_stack.providers.registry.inference import available_providers
|
from llama_stack.providers.registry.inference import available_providers
|
||||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||||
|
from llama_stack.providers.remote.vector_io.mongodb.config import MongoDBVectorIOConfig
|
||||||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||||
PGVectorVectorIOConfig,
|
PGVectorVectorIOConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
|
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
|
||||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
from llama_stack.providers.remote.vector_io.weaviate.config import (
|
||||||
|
WeaviateVectorIOConfig,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||||
|
|
||||||
|
|
@ -124,6 +127,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
||||||
BuildProvider(provider_type="inline::milvus"),
|
BuildProvider(provider_type="inline::milvus"),
|
||||||
BuildProvider(provider_type="remote::chromadb"),
|
BuildProvider(provider_type="remote::chromadb"),
|
||||||
BuildProvider(provider_type="remote::pgvector"),
|
BuildProvider(provider_type="remote::pgvector"),
|
||||||
|
BuildProvider(provider_type="remote::mongodb"),
|
||||||
BuildProvider(provider_type="remote::qdrant"),
|
BuildProvider(provider_type="remote::qdrant"),
|
||||||
BuildProvider(provider_type="remote::weaviate"),
|
BuildProvider(provider_type="remote::weaviate"),
|
||||||
],
|
],
|
||||||
|
|
@@ -254,7 +258,70 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
         additional_pip_packages=list(set(PostgresSqlStoreConfig.pip_packages() + PostgresKVStoreConfig.pip_packages())),
         run_configs={
             "run.yaml": RunConfigSettings(
-                provider_overrides=default_overrides,
+                provider_overrides={
+                    "inference": remote_inference_providers + [embedding_provider],
+                    "vector_io": [
+                        Provider(
+                            provider_id="faiss",
+                            provider_type="inline::faiss",
+                            config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+                        ),
+                        Provider(
+                            provider_id="sqlite-vec",
+                            provider_type="inline::sqlite-vec",
+                            config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+                        ),
+                        Provider(
+                            provider_id="${env.MILVUS_URL:+milvus}",
+                            provider_type="inline::milvus",
+                            config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+                        ),
+                        Provider(
+                            provider_id="${env.CHROMADB_URL:+chromadb}",
+                            provider_type="remote::chromadb",
+                            config=ChromaVectorIOConfig.sample_run_config(
+                                f"~/.llama/distributions/{name}/",
+                                url="${env.CHROMADB_URL:=}",
+                            ),
+                        ),
+                        Provider(
+                            provider_id="${env.PGVECTOR_DB:+pgvector}",
+                            provider_type="remote::pgvector",
+                            config=PGVectorVectorIOConfig.sample_run_config(
+                                f"~/.llama/distributions/{name}",
+                                db="${env.PGVECTOR_DB:=}",
+                                user="${env.PGVECTOR_USER:=}",
+                                password="${env.PGVECTOR_PASSWORD:=}",
+                            ),
+                        ),
+                        Provider(
+                            provider_id="${env.MONGODB_CONNECTION_STRING:+mongodb_atlas}",
+                            provider_type="remote::mongodb",
+                            config=MongoDBVectorIOConfig.sample_run_config(
+                                f"~/.llama/distributions/{name}",
+                                connection_string="${env.MONGODB_CONNECTION_STRING:=}",
+                                database_name="${env.MONGODB_DATABASE_NAME:=llama_stack}",
+                            ),
+                        ),
+                        Provider(
+                            provider_id="${env.QDRANT_URL:+qdrant}",
+                            provider_type="remote::qdrant",
+                            config=QdrantVectorIOConfig.sample_run_config(
+                                f"~/.llama/distributions/{name}",
+                                url="${env.QDRANT_URL:=}",
+                            ),
+                        ),
+                        Provider(
+                            provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
+                            provider_type="remote::weaviate",
+                            config=WeaviateVectorIOConfig.sample_run_config(
+                                f"~/.llama/distributions/{name}",
+                                cluster_url="${env.WEAVIATE_CLUSTER_URL:=}",
+                            ),
+                        ),
+                    ],
+                    "files": [files_provider],
+                },
                 default_models=[],
                 default_tool_groups=default_tool_groups,
                 default_shields=default_shields,
@@ -384,5 +451,13 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
             "azure",
             "Azure API Type",
         ),
+        "MONGODB_CONNECTION_STRING": (
+            "",
+            "MongoDB Atlas connection string (e.g., mongodb+srv://user:pass@cluster.mongodb.net/)",
+        ),
+        "MONGODB_DATABASE_NAME": (
+            "llama_stack",
+            "MongoDB database name",
+        ),
     },
 )
@@ -823,6 +823,132 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
         optional_api_dependencies=[Api.files, Api.models],
         description="""
 Please refer to the remote provider documentation.
+""",
+    ),
+    RemoteProviderSpec(
+        api=Api.vector_io,
+        adapter_type="mongodb",
+        provider_type="remote::mongodb",
+        pip_packages=["pymongo>=4.0.0"],
+        module="llama_stack.providers.remote.vector_io.mongodb",
+        config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
+        description="""
+[MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It
+uses MongoDB Atlas Vector Search to store and query vectors in the cloud,
+giving you enterprise-grade vector search with MongoDB's scalability and reliability.
+
+## Features
+
+- Cloud-native vector search with MongoDB Atlas
+- Fully integrated with Llama Stack
+- Enterprise-grade security and scalability
+- Supports multiple search modes: vector, keyword, and hybrid search
+- Built-in metadata filtering and text search capabilities
+- Automatic index management
+
+## Search Modes
+
+MongoDB Atlas Vector Search supports three search modes:
+
+### Vector Search
+Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search over embedding vectors.
+
+```python
+# Vector search example
+search_response = client.vector_stores.search(
+    vector_store_id=vector_store.id,
+    query="What is machine learning?",
+    search_mode="vector",
+    max_num_results=5,
+)
+```
+
+### Keyword Search
+Keyword search uses MongoDB's full-text indexes to find chunks containing specific terms.
+
+```python
+# Keyword search example
+search_response = client.vector_stores.search(
+    vector_store_id=vector_store.id,
+    query="Python programming language",
+    search_mode="keyword",
+    max_num_results=5,
+)
+```
+
+### Hybrid Search
+Hybrid search combines vector and keyword search using configurable reranking algorithms.
+
+```python
+# Hybrid search with RRF ranker (default)
+search_response = client.vector_stores.search(
+    vector_store_id=vector_store.id,
+    query="neural networks in Python",
+    search_mode="hybrid",
+    max_num_results=5,
+)
+
+# Hybrid search with weighted ranker
+search_response = client.vector_stores.search(
+    vector_store_id=vector_store.id,
+    query="neural networks in Python",
+    search_mode="hybrid",
+    max_num_results=5,
+    ranking_options={
+        "ranker": {
+            "type": "weighted",
+            "alpha": 0.7,  # 70% vector search, 30% keyword search
+        }
+    },
+)
+```
+
+## Usage
+
+To use MongoDB Atlas in your Llama Stack project, follow these steps:
+
+1. Create a MongoDB Atlas cluster with Vector Search enabled.
+2. Install the necessary dependencies.
+3. Configure your Llama Stack project to use MongoDB.
+4. Start storing and querying vectors.
+
+## Configuration
+
+### Environment Variables
+Set the following environment variable for your MongoDB Atlas connection:
+
+```bash
+export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack"
+```
+
+### Configuration Example
+
+```yaml
+vector_io:
+  - provider_id: mongodb_atlas
+    provider_type: remote::mongodb
+    config:
+      connection_string: "${env.MONGODB_CONNECTION_STRING}"
+      database_name: "llama_stack"
+      index_name: "vector_index"
+      similarity_metric: "cosine"
+```
+
+## Installation
+
+Install the MongoDB Python driver with pip:
+
+```bash
+pip install pymongo
+```
+
+## Documentation
+
+See the [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details.
+
+For general MongoDB documentation, visit the [MongoDB Documentation](https://docs.mongodb.com/).
 """,
     ),
 ]
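The weighted ranker shown in the hybrid-search example above is, in effect, a convex combination of the two score sets. A minimal sketch of that combination, assuming plain per-document scores (a hypothetical helper for illustration; the adapter itself delegates to `WeightedInMemoryAggregator.combine_search_results`, which may also normalize scores first):

```python
def combine_weighted(
    vector_scores: dict[str, float],
    keyword_scores: dict[str, float],
    alpha: float = 0.7,
) -> dict[str, float]:
    """alpha weights the vector score; (1 - alpha) weights the keyword score."""
    all_ids = set(vector_scores) | set(keyword_scores)
    return {
        doc_id: alpha * vector_scores.get(doc_id, 0.0) + (1 - alpha) * keyword_scores.get(doc_id, 0.0)
        for doc_id in all_ids
    }


# With alpha=0.7, a chunk scoring 0.9 on vector and 0.4 on keyword combines to ~0.75.
print(combine_weighted({"c1": 0.9}, {"c1": 0.4})["c1"])
```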
src/llama_stack/providers/remote/vector_io/mongodb/__init__.py (new file)

@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import MongoDBVectorIOConfig


async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: dict[Api, ProviderSpec]):
    from .mongodb import MongoDBVectorIOAdapter

    # Handle the deps resolution - if files API exists, pass it, otherwise None
    files_api = deps.get(Api.files)
    models_api = deps.get(Api.models)
    impl = MongoDBVectorIOAdapter(config, deps[Api.inference], files_api, models_api)
    await impl.initialize()
    return impl
src/llama_stack/providers/remote/vector_io/mongodb/config.py (new file, 110 lines)

@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class MongoDBVectorIOConfig(BaseModel):
    """Configuration for the MongoDB Atlas Vector Search provider.

    This provider connects to MongoDB Atlas and uses Vector Search for RAG operations.
    """

    # MongoDB connection details - either connection_string or individual parameters
    connection_string: str | None = Field(
        default=None,
        description="MongoDB connection string (e.g., mongodb://user:pass@localhost:27017/ or mongodb+srv://user:pass@cluster.mongodb.net/)",
    )
    host: str | None = Field(
        default=None,
        description="MongoDB host (used if connection_string is not provided)",
    )
    port: int | None = Field(
        default=None,
        description="MongoDB port (used if connection_string is not provided)",
    )
    username: str | None = Field(
        default=None,
        description="MongoDB username (used if connection_string is not provided)",
    )
    password: str | None = Field(
        default=None,
        description="MongoDB password (used if connection_string is not provided)",
    )
    database_name: str = Field(default="llama_stack", description="Database name to use for vector collections")

    # Vector search configuration
    index_name: str = Field(default="vector_index", description="Name of the vector search index")
    path_field: str = Field(default="embedding", description="Field name for storing embeddings")
    similarity_metric: str = Field(
        default="cosine",
        description="Similarity metric: cosine, euclidean, or dotProduct",
    )

    # Connection options
    max_pool_size: int = Field(default=100, description="Maximum connection pool size")
    timeout_ms: int = Field(default=30000, description="Connection timeout in milliseconds")

    # KV store configuration
    persistence: KVStoreReference | None = Field(
        description="Config for KV store backend for metadata storage", default=None
    )

    def get_connection_string(self) -> str | None:
        """Build a connection string from individual parameters if one is not provided directly.

        If both connection_string and individual parameters (host/port) are provided,
        the individual parameters take precedence to allow test environment overrides.
        """
        # Prioritize individual connection parameters over connection_string
        # This allows test environments to override with MONGODB_HOST/PORT/etc
        if self.host and self.port:
            auth_part = ""
            if self.username and self.password:
                auth_part = f"{self.username}:{self.password}@"
            return f"mongodb://{auth_part}{self.host}:{self.port}/"

        # Fall back to connection_string if provided
        if self.connection_string:
            return self.connection_string

        return None

    @classmethod
    def sample_run_config(
        cls,
        __distro_dir__: str,
        connection_string: str = "${env.MONGODB_CONNECTION_STRING:=}",
        host: str = "${env.MONGODB_HOST:=localhost}",
        port: str = "${env.MONGODB_PORT:=27017}",
        username: str = "${env.MONGODB_USERNAME:=}",
        password: str = "${env.MONGODB_PASSWORD:=}",
        database_name: str = "${env.MONGODB_DATABASE_NAME:=llama_stack}",
        **kwargs: Any,
    ) -> dict[str, Any]:
        return {
            "connection_string": connection_string,
            "host": host,
            "port": port,
            "username": username,
            "password": password,
            "database_name": database_name,
            "index_name": "${env.MONGODB_INDEX_NAME:=vector_index}",
            "path_field": "${env.MONGODB_PATH_FIELD:=embedding}",
            "similarity_metric": "${env.MONGODB_SIMILARITY_METRIC:=cosine}",
            "max_pool_size": "${env.MONGODB_MAX_POOL_SIZE:=100}",
            "timeout_ms": "${env.MONGODB_TIMEOUT_MS:=30000}",
            "persistence": KVStoreReference(
                backend="kv_default",
                namespace="vector_io::mongodb_atlas",
            ).model_dump(exclude_none=True),
        }
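A quick sketch of the precedence rule implemented by `get_connection_string()` above, runnable against the config class as defined in this diff (credentials are placeholders):

```python
from llama_stack.providers.remote.vector_io.mongodb.config import MongoDBVectorIOConfig

# Individual host/port parameters win over connection_string, so a test
# environment can override a baked-in Atlas URI:
cfg = MongoDBVectorIOConfig(
    connection_string="mongodb+srv://user:pass@cluster.mongodb.net/",
    host="localhost",
    port=27017,
    username="myuser",
    password="mypassword",
)
assert cfg.get_connection_string() == "mongodb://myuser:mypassword@localhost:27017/"

# Without host/port, the connection string is used as-is:
cfg = MongoDBVectorIOConfig(connection_string="mongodb+srv://user:pass@cluster.mongodb.net/")
assert cfg.get_connection_string() == "mongodb+srv://user:pass@cluster.mongodb.net/"

# With neither, the adapter skips the MongoDB connection entirely:
assert MongoDBVectorIOConfig().get_connection_string() is None
```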
src/llama_stack/providers/remote/vector_io/mongodb/mongodb.py (new file, 631 lines)

@@ -0,0 +1,631 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import heapq
import json
from typing import Any

from numpy.typing import NDArray
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.operations import SearchIndexModel
from pymongo.server_api import ServerApi

from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    VectorStoresProtocolPrivate,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
    OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.vector_store import (
    ChunkForDeletion,
    EmbeddingIndex,
    VectorStoreWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import (
    WeightedInMemoryAggregator,
    sanitize_collection_name,
)

from .config import MongoDBVectorIOConfig

logger = get_logger(name=__name__, category="vector_io::mongodb")

VERSION = "v1"
VECTOR_DBS_PREFIX = f"vector_dbs:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:mongodb:{VERSION}::"


class MongoDBIndex(EmbeddingIndex):
    """MongoDB Atlas Vector Search index implementation optimized for RAG."""

    def __init__(
        self,
        vector_store: VectorStore,
        collection: Collection,
        config: MongoDBVectorIOConfig,
    ):
        self.vector_store = vector_store
        self.collection = collection
        self.config = config
        self.dimension = vector_store.embedding_dimension

    async def initialize(self) -> None:
        """Initialize the MongoDB collection and ensure the vector search index exists."""
        try:
            # Create the collection if it doesn't exist
            collection_names = self.collection.database.list_collection_names()
            if self.collection.name not in collection_names:
                logger.info(f"Creating collection '{self.collection.name}'")
                # Create collection by inserting a dummy document
                dummy_doc = {"_id": "__dummy__", "dummy": True}
                self.collection.insert_one(dummy_doc)
                # Remove the dummy document
                self.collection.delete_one({"_id": "__dummy__"})
                logger.info(f"Collection '{self.collection.name}' created successfully")

            # Create optimized vector search index for RAG
            await self._create_vector_search_index()

            # Create text index for hybrid search
            await self._ensure_text_index()

        except Exception as e:
            logger.exception(
                f"Failed to initialize MongoDB index for vector_store: {self.vector_store.identifier}. "
                f"Collection name: {self.collection.name}. Error: {str(e)}"
            )
            # Don't fail completely - just log the error and continue
            logger.warning(
                "Continuing without complete index initialization. "
                "You may need to create indexes manually in the MongoDB Atlas dashboard."
            )

    async def _create_vector_search_index(self) -> None:
        """Create an optimized vector search index based on MongoDB RAG best practices."""
        try:
            # Check if the vector search index exists
            indexes = list(self.collection.list_search_indexes())
            index_exists = any(idx.get("name") == self.config.index_name for idx in indexes)

            if not index_exists:
                # Create vector search index optimized for RAG
                # Based on MongoDB's RAG example using the new vectorSearch format
                search_index_model = SearchIndexModel(
                    definition={
                        "fields": [
                            {
                                "type": "vector",
                                "numDimensions": self.dimension,
                                "path": self.config.path_field,
                                "similarity": self._convert_similarity_metric(self.config.similarity_metric),
                            }
                        ]
                    },
                    name=self.config.index_name,
                    type="vectorSearch",
                )

                logger.info(
                    f"Creating vector search index '{self.config.index_name}' for RAG on collection '{self.collection.name}'"
                )

                self.collection.create_search_index(model=search_index_model)

                # Wait for the index to be ready (like in the MongoDB RAG example)
                await self._wait_for_index_ready()

                logger.info("Vector search index created and ready for RAG queries")

        except Exception as e:
            logger.warning(f"Failed to create vector search index: {e}")

    def _convert_similarity_metric(self, metric: str) -> str:
        """Convert the internal similarity metric to MongoDB Atlas format."""
        metric_map = {
            "cosine": "cosine",
            "euclidean": "euclidean",
            "dotProduct": "dotProduct",
            "dot_product": "dotProduct",
        }
        return metric_map.get(metric, "cosine")

    async def _wait_for_index_ready(self) -> None:
        """Wait for the vector search index to become queryable, based on the MongoDB RAG example."""
        logger.info("Waiting for vector search index to be ready...")

        max_wait_time = 300  # 5 minutes max wait
        wait_interval = 5
        elapsed_time = 0

        while elapsed_time < max_wait_time:
            try:
                indices = list(self.collection.list_search_indexes(self.config.index_name))
                if len(indices) and indices[0].get("queryable") is True:
                    logger.info(f"Vector search index '{self.config.index_name}' is ready for querying")
                    return
            except Exception:
                pass

            # Non-blocking sleep so the event loop is not stalled while polling
            await asyncio.sleep(wait_interval)
            elapsed_time += wait_interval

        logger.warning(f"Vector search index may not be fully ready after {max_wait_time}s")

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray) -> None:
        """Add chunks with embeddings to the MongoDB collection, structured for RAG."""
        if len(chunks) != len(embeddings):
            raise ValueError(f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}")

        documents = []
        for i, chunk in enumerate(chunks):
            # Structure document for optimal RAG retrieval
            doc = {
                "_id": chunk.chunk_id,
                "chunk_id": chunk.chunk_id,
                "text": interleaved_content_as_str(chunk.content),  # Key field for RAG context
                "content": interleaved_content_as_str(chunk.content),  # Backward compatibility
                "metadata": chunk.metadata or {},
                "chunk_metadata": (chunk.chunk_metadata.model_dump() if chunk.chunk_metadata else {}),
                self.config.path_field: embeddings[i].tolist(),  # Vector embeddings
                "document": chunk.model_dump(),  # Full chunk data
            }
            documents.append(doc)

        try:
            # Use upsert behavior for chunks
            for doc in documents:
                self.collection.replace_one({"_id": doc["_id"]}, doc, upsert=True)

            logger.debug(f"Successfully added {len(chunks)} chunks optimized for RAG to MongoDB collection")
        except Exception as e:
            logger.exception(f"Failed to add chunks to MongoDB collection: {e}")
            raise

    async def query_vector(
        self,
        embedding: NDArray,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        """Perform vector similarity search optimized for RAG, based on the MongoDB example."""
        try:
            # Use MongoDB's vector search aggregation pipeline optimized for RAG
            pipeline = [
                {
                    "$vectorSearch": {
                        "index": self.config.index_name,
                        "queryVector": embedding.tolist(),
                        "path": self.config.path_field,
                        "numCandidates": min(k * 10, 1000),  # Cap at 1000 to prevent excessive candidates
                        "limit": k,
                    }
                },
                {
                    "$project": {
                        "_id": 0,
                        "text": 1,  # Primary field for RAG context
                        "content": 1,  # Backward compatibility
                        "metadata": 1,
                        "chunk_metadata": 1,
                        "document": 1,
                        "score": {"$meta": "vectorSearchScore"},
                    }
                },
                {"$match": {"score": {"$gte": score_threshold}}},
            ]

            results = list(self.collection.aggregate(pipeline))

            chunks = []
            scores = []
            for result in results:
                score = result.get("score", 0.0)
                if score >= score_threshold:
                    chunk_data = result.get("document", {})
                    if chunk_data:
                        chunks.append(Chunk(**chunk_data))
                        scores.append(float(score))

            logger.debug(f"Vector search for RAG returned {len(chunks)} results")
            return QueryChunksResponse(chunks=chunks, scores=scores)

        except Exception as e:
            logger.exception(f"Vector search for RAG failed: {e}")
            raise RuntimeError(f"Vector search for RAG failed: {e}") from e

    async def query_keyword(
        self,
        query_string: str,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        """Perform text search using MongoDB's text search for RAG context retrieval."""
        try:
            # Ensure the text index exists
            await self._ensure_text_index()

            pipeline: list[dict[str, Any]] = [
                {"$match": {"$text": {"$search": query_string}}},
                {
                    "$project": {
                        "_id": 0,
                        "text": 1,  # Primary field for RAG context
                        "content": 1,  # Backward compatibility
                        "metadata": 1,
                        "chunk_metadata": 1,
                        "document": 1,
                        "score": {"$meta": "textScore"},
                    }
                },
                {"$match": {"score": {"$gte": score_threshold}}},
                {"$sort": {"score": {"$meta": "textScore"}}},
                {"$limit": k},
            ]

            results = list(self.collection.aggregate(pipeline))

            chunks = []
            scores = []
            for result in results:
                score = result.get("score", 0.0)
                if score >= score_threshold:
                    chunk_data = result.get("document", {})
                    if chunk_data:
                        chunks.append(Chunk(**chunk_data))
                        scores.append(float(score))

            logger.debug(f"Keyword search for RAG returned {len(chunks)} results")
            return QueryChunksResponse(chunks=chunks, scores=scores)

        except Exception as e:
            logger.exception(f"Keyword search for RAG failed: {e}")
            raise RuntimeError(f"Keyword search for RAG failed: {e}") from e

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Perform hybrid search for enhanced RAG context retrieval."""
        if reranker_params is None:
            reranker_params = {}

        # Get results from both search methods
        vector_response = await self.query_vector(embedding, k, 0.0)
        keyword_response = await self.query_keyword(query_string, k, 0.0)

        # Convert responses to score dictionaries
        vector_scores = {
            chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
        }
        keyword_scores = {
            chunk.chunk_id: score
            for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
        }

        # Combine scores using the reranking utility
        combined_scores = WeightedInMemoryAggregator.combine_search_results(
            vector_scores, keyword_scores, reranker_type, reranker_params
        )

        # Get top-k results
        top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])

        # Filter by score threshold
        filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]

        # Create chunk map
        chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}

        # Build final results
        chunks = []
        scores = []
        for doc_id, score in filtered_items:
            if doc_id in chunk_map:
                chunks.append(chunk_map[doc_id])
                scores.append(score)

        logger.debug(f"Hybrid search for RAG returned {len(chunks)} results")
        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Delete chunks from the MongoDB collection."""
        chunk_ids = [c.chunk_id for c in chunks_for_deletion]
        try:
            result = self.collection.delete_many({"_id": {"$in": chunk_ids}})
            logger.debug(f"Deleted {result.deleted_count} chunks from MongoDB collection")
        except Exception as e:
            logger.exception(f"Failed to delete chunks: {e}")
            raise

    async def delete(self) -> None:
        """Delete the entire collection."""
        try:
            self.collection.drop()
            logger.debug(f"Dropped MongoDB collection: {self.collection.name}")
        except Exception as e:
            logger.exception(f"Failed to drop collection: {e}")
            raise

    async def _ensure_text_index(self) -> None:
        """Ensure a text search index exists on the content fields for RAG."""
        try:
            indexes = list(self.collection.list_indexes())
            text_index_exists = any(
                any(key.startswith(("content", "text")) for key in idx.get("key", {}).keys())
                and idx.get("textIndexVersion") is not None
                for idx in indexes
            )

            if not text_index_exists:
                logger.info("Creating text search index on content fields for RAG")
                # Index both 'text' and 'content' fields for comprehensive text search
                self.collection.create_index([("text", "text"), ("content", "text")])
                logger.info("Text search index created successfully for RAG")

        except Exception as e:
            logger.warning(f"Failed to create text index for RAG: {e}")


class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
    """MongoDB Atlas Vector Search adapter for Llama Stack, optimized for RAG workflows."""

    def __init__(
        self,
        config: MongoDBVectorIOConfig,
        inference_api,
        files_api=None,
        models_api=None,
    ) -> None:
        # Handle the case where files_api might be a ProviderSpec that needs resolution
        resolved_files_api = files_api
        super().__init__(files_api=resolved_files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
        self.models_api = models_api
        self.client: MongoClient | None = None
        self.database: Database | None = None
        self.cache: dict[str, VectorStoreWithIndex] = {}
        self.kvstore: KVStore | None = None

    async def initialize(self) -> None:
        """Initialize the MongoDB connection, optimized for RAG workflows."""
        logger.info("Initializing MongoDB Atlas Vector IO adapter for RAG")

        try:
            # Initialize KV store for metadata
            if self.config.persistence:
                self.kvstore = await kvstore_impl(self.config.persistence)

            # Get connection string from config (either direct or built from parameters)
            connection_string = self.config.get_connection_string()

            # Skip the MongoDB connection if no connection parameters are provided
            # This allows other providers to work without MongoDB credentials
            if not connection_string:
                logger.warning(
                    "MongoDB connection parameters not provided. "
                    "MongoDB vector store will not be available until credentials are configured."
                )
                return

            # Connect to MongoDB with optimized settings for RAG
            self.client = MongoClient(
                connection_string,
                server_api=ServerApi("1"),
                maxPoolSize=self.config.max_pool_size,
                serverSelectionTimeoutMS=self.config.timeout_ms,
                # Additional settings for RAG performance
                retryWrites=True,
                readPreference="primaryPreferred",
            )

            # Test connection
            try:
                self.client.admin.command("ping")
                logger.info("Successfully connected to MongoDB Atlas for RAG")
            except Exception as conn_error:
                # Extract just the basic error type without the full traceback
                error_type = type(conn_error).__name__
                logger.warning(
                    f"MongoDB connection failed ({error_type}). "
                    "MongoDB vector store will not be available. "
                    f"Attempted to connect to: {self.config.host or 'connection_string'}:{self.config.port or '(from connection_string)'}"
                )
                # Close the client and clear it
                if self.client:
                    self.client.close()
                self.client = None
                return

            # Get database
            self.database = self.client[self.config.database_name]

            # Initialize OpenAI vector stores
            await self.initialize_openai_vector_stores()

            # Load existing vector databases
            await self._load_existing_vector_dbs()

            logger.info("MongoDB Atlas Vector IO adapter for RAG initialized successfully")

        except Exception as e:
            logger.exception("Failed to initialize MongoDB Atlas Vector IO adapter for RAG")
            # Close the client if it was created
            if self.client:
                self.client.close()
                self.client = None
            # Log a warning instead of raising to allow tests to skip gracefully
            logger.warning(f"MongoDB initialization failed: {e}. MongoDB vector store will not be available.")

    async def shutdown(self) -> None:
        """Shut down the MongoDB connection."""
        if self.client:
            self.client.close()
            logger.info("MongoDB Atlas RAG connection closed")

    async def health(self) -> HealthResponse:
        """Perform a health check on the MongoDB connection."""
        try:
            if self.client:
                self.client.admin.command("ping")
                return HealthResponse(status=HealthStatus.OK)
            else:
                return HealthResponse(status=HealthStatus.ERROR, message="MongoDB client not initialized")
        except Exception as e:
            return HealthResponse(
                status=HealthStatus.ERROR,
                message=f"MongoDB RAG health check failed: {str(e)}",
            )

    async def register_vector_store(self, vector_store: VectorStore) -> None:
        """Register a new vector store optimized for RAG."""
        if self.database is None:
            raise RuntimeError("MongoDB database not initialized")

        # Create collection name from vector store identifier
        collection_name = sanitize_collection_name(vector_store.identifier)
        collection = self.database[collection_name]

        # Create and initialize MongoDB index optimized for RAG
        mongodb_index = MongoDBIndex(vector_store, collection, self.config)
        await mongodb_index.initialize()

        # Create vector store with index wrapper
        vector_store_with_index = VectorStoreWithIndex(
            vector_store=vector_store,
            index=mongodb_index,
            inference_api=self.inference_api,
        )

        # Cache the vector store
        self.cache[vector_store.identifier] = vector_store_with_index

        # Save vector store info to KVStore for persistence
        if self.kvstore:
            await self.kvstore.set(
                f"{VECTOR_DBS_PREFIX}{vector_store.identifier}",
                vector_store.model_dump_json(),
            )

        logger.info(f"Registered vector store for RAG: {vector_store.identifier}")

    async def unregister_vector_store(self, vector_store_id: str) -> None:
        """Unregister a vector store."""
        if vector_store_id in self.cache:
            await self.cache[vector_store_id].index.delete()
            del self.cache[vector_store_id]

        # Clean up from KV store
        if self.kvstore:
            await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")

        logger.info(f"Unregistered vector store: {vector_store_id}")

    async def insert_chunks(
        self,
        vector_store_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        """Insert chunks into the vector database, optimized for RAG."""
        vector_db_with_index = await self._get_vector_db_index(vector_store_id)
        await vector_db_with_index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_store_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Query chunks from the vector database, optimized for RAG context retrieval."""
        vector_db_with_index = await self._get_vector_db_index(vector_store_id)
        return await vector_db_with_index.query_chunks(query, params)

    async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Delete chunks from the vector database."""
        vector_db_with_index = await self._get_vector_db_index(store_id)
        await vector_db_with_index.index.delete_chunks(chunks_for_deletion)

    async def _get_vector_db_index(self, vector_db_id: str) -> VectorStoreWithIndex:
        """Get a vector store index from the cache."""
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        raise VectorStoreNotFoundError(vector_db_id)

    async def _load_existing_vector_dbs(self) -> None:
        """Load existing vector databases from the KVStore."""
        if not self.kvstore:
            return

        try:
            # Use keys_in_range to get all vector database keys from the KVStore
            # This searches for keys with the prefix by using a range scan
            start_key = VECTOR_DBS_PREFIX
            # Create an end key by incrementing the last character
            end_key = VECTOR_DBS_PREFIX[:-1] + chr(ord(VECTOR_DBS_PREFIX[-1]) + 1)

            vector_db_keys = await self.kvstore.keys_in_range(start_key, end_key)

            for key in vector_db_keys:
                try:
                    vector_store_data = await self.kvstore.get(key)
                    if vector_store_data:
                        vector_store = VectorStore(**json.loads(vector_store_data))
                        # Register the vector store without re-initializing
                        await self._register_existing_vector_store(vector_store)
                        logger.info(f"Loaded existing RAG-optimized vector store: {vector_store.identifier}")
                except Exception as e:
                    logger.warning(f"Failed to load vector store from key {key}: {e}")
                    continue

        except Exception as e:
            logger.warning(f"Failed to load existing vector stores: {e}")

    async def _register_existing_vector_store(self, vector_store: VectorStore) -> None:
        """Register an existing vector store without re-initialization."""
        if self.database is None:
            raise RuntimeError("MongoDB database not initialized")

        # Create collection name from vector store identifier
        collection_name = sanitize_collection_name(vector_store.identifier)
        collection = self.database[collection_name]

        # Create MongoDB index without initialization (collection already exists)
        mongodb_index = MongoDBIndex(vector_store, collection, self.config)

        # Create vector store with index wrapper
        vector_store_with_index = VectorStoreWithIndex(
            vector_store=vector_store,
            index=mongodb_index,
            inference_api=self.inference_api,
        )

        # Cache the vector store
        self.cache[vector_store.identifier] = vector_store_with_index
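One detail worth calling out in `_load_existing_vector_dbs` above: the prefix scan over the KV store is expressed as a half-open key range whose end key is the prefix with its last character incremented. A standalone sketch of why that bound captures exactly the prefixed keys:

```python
prefix = "vector_dbs:mongodb:v1::"
start_key = prefix
end_key = prefix[:-1] + chr(ord(prefix[-1]) + 1)  # ":" -> ";", one past the prefix

# Every key that starts with the prefix sorts inside [start_key, end_key) ...
assert start_key <= "vector_dbs:mongodb:v1::my_store" < end_key
# ... and keys with a different prefix fall outside it.
assert not (start_key <= "vector_dbs:qdrant:v1::my_store" < end_key)
```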
tests/unit/providers/vector_io/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
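The test package marker above ships empty. As a hedged sketch, a unit test of the pure config logic from this diff could sit next to it (hypothetical file name and tests, not part of the commit):

```python
# tests/unit/providers/vector_io/test_mongodb_config.py (hypothetical)
from llama_stack.providers.remote.vector_io.mongodb.config import MongoDBVectorIOConfig


def test_connection_string_requires_both_host_and_port():
    # host alone is not enough to build a URI
    assert MongoDBVectorIOConfig(host="localhost").get_connection_string() is None


def test_auth_part_requires_both_username_and_password():
    # username without password is ignored in the URI
    cfg = MongoDBVectorIOConfig(host="localhost", port=27017, username="u")
    assert cfg.get_connection_string() == "mongodb://localhost:27017/"
```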