mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
Merge 503ad16002 into 356f37b1ba
This commit is contained in:
commit
fdc9ba2687
17 changed files with 2066 additions and 3 deletions
|
|
@ -31,7 +31,7 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"]
|
||||
vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant", "remote::mongodb"]
|
||||
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
|
||||
fail-fast: false # we want to run all tests regardless of failure
|
||||
|
||||
|
|
@ -101,6 +101,16 @@ jobs:
|
|||
-p 6333:6333 \
|
||||
qdrant/qdrant
|
||||
|
||||
- name: Setup MongoDB
|
||||
if: matrix.vector-io-provider == 'remote::mongodb'
|
||||
run: |
|
||||
docker run --rm -d --pull always \
|
||||
--name mongodb \
|
||||
-p 27017:27017 \
|
||||
-e MONGO_INITDB_ROOT_USERNAME=llamastack \
|
||||
-e MONGO_INITDB_ROOT_PASSWORD=llamastack \
|
||||
mongodb/mongodb-atlas-local:latest
|
||||
|
||||
- name: Wait for Qdrant to be ready
|
||||
if: matrix.vector-io-provider == 'remote::qdrant'
|
||||
run: |
|
||||
|
|
@ -116,6 +126,21 @@ jobs:
|
|||
docker logs qdrant
|
||||
exit 1
|
||||
|
||||
- name: Wait for MongoDB to be ready
|
||||
if: matrix.vector-io-provider == 'remote::mongodb'
|
||||
run: |
|
||||
echo "Waiting for MongoDB to be ready..."
|
||||
for i in {1..30}; do
|
||||
if docker exec mongodb mongosh --quiet --eval "db.adminCommand('ping').ok" > /dev/null 2>&1; then
|
||||
echo "MongoDB is ready!"
|
||||
exit 0
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
echo "MongoDB failed to start"
|
||||
docker logs mongodb
|
||||
exit 1
|
||||
|
||||
- name: Wait for ChromaDB to be ready
|
||||
if: matrix.vector-io-provider == 'remote::chromadb'
|
||||
run: |
|
||||
|
|
@ -170,6 +195,11 @@ jobs:
|
|||
QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
|
||||
ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
|
||||
WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
|
||||
ENABLE_MONGODB: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'true' || '' }}
|
||||
MONGODB_HOST: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'localhost' || '' }}
|
||||
MONGODB_PORT: ${{ matrix.vector-io-provider == 'remote::mongodb' && '27017' || '' }}
|
||||
MONGODB_USERNAME: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'llamastack' || '' }}
|
||||
MONGODB_PASSWORD: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'llamastack' || '' }}
|
||||
run: |
|
||||
uv run --no-sync \
|
||||
pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
|
||||
|
|
@ -196,6 +226,11 @@ jobs:
|
|||
run: |
|
||||
docker logs qdrant > qdrant.log
|
||||
|
||||
- name: Write MongoDB logs to file
|
||||
if: ${{ always() && matrix.vector-io-provider == 'remote::mongodb' }}
|
||||
run: |
|
||||
docker logs mongodb > mongodb.log
|
||||
|
||||
- name: Upload all logs to artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||
|
|
|
|||
276
docs/docs/providers/vector_io/remote_mongodb.mdx
Normal file
276
docs/docs/providers/vector_io/remote_mongodb.mdx
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
---
|
||||
description: |
|
||||
[MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It
|
||||
uses MongoDB Atlas Vector Search to store and query vectors in the cloud.
|
||||
That means you get enterprise-grade vector search with MongoDB's scalability and reliability.
|
||||
|
||||
## Features
|
||||
|
||||
- Cloud-native vector search with MongoDB Atlas
|
||||
- Fully integrated with Llama Stack
|
||||
- Enterprise-grade security and scalability
|
||||
- Supports multiple search modes: vector, keyword, and hybrid search
|
||||
- Built-in metadata filtering and text search capabilities
|
||||
- Automatic index management
|
||||
|
||||
## Search Modes
|
||||
|
||||
MongoDB Atlas Vector Search supports three different search modes:
|
||||
|
||||
### Vector Search
|
||||
Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search using embedding vectors.
|
||||
|
||||
```python
|
||||
# Vector search example
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="What is machine learning?",
|
||||
search_mode="vector",
|
||||
max_num_results=5,
|
||||
)
|
||||
```
|
||||
|
||||
### Keyword Search
|
||||
Keyword search uses MongoDB's text search capabilities with full-text indexes to find chunks containing specific terms.
|
||||
|
||||
```python
|
||||
# Keyword search example
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="Python programming language",
|
||||
search_mode="keyword",
|
||||
max_num_results=5,
|
||||
)
|
||||
```
|
||||
|
||||
### Hybrid Search
|
||||
Hybrid search combines both vector and keyword search methods using configurable reranking algorithms.
|
||||
|
||||
```python
|
||||
# Hybrid search with RRF ranker (default)
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="neural networks in Python",
|
||||
search_mode="hybrid",
|
||||
max_num_results=5,
|
||||
)
|
||||
|
||||
# Hybrid search with weighted ranker
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="neural networks in Python",
|
||||
search_mode="hybrid",
|
||||
max_num_results=5,
|
||||
ranking_options={
|
||||
"ranker": {
|
||||
"type": "weighted",
|
||||
"alpha": 0.7, # 70% vector search, 30% keyword search
|
||||
}
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use MongoDB Atlas in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Create a MongoDB Atlas cluster with Vector Search enabled
|
||||
2. Install the necessary dependencies
|
||||
3. Configure your Llama Stack project to use MongoDB
|
||||
4. Start storing and querying vectors
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
Set up the following environment variable for your MongoDB Atlas connection:
|
||||
|
||||
```bash
|
||||
export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack"
|
||||
```
|
||||
|
||||
### Configuration Example
|
||||
|
||||
```yaml
|
||||
vector_io:
|
||||
- provider_id: mongodb_atlas
|
||||
provider_type: remote::mongodb
|
||||
config:
|
||||
connection_string: "${env.MONGODB_CONNECTION_STRING}"
|
||||
database_name: "llama_stack"
|
||||
index_name: "vector_index"
|
||||
similarity_metric: "cosine"
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
You can install the MongoDB Python driver using pip:
|
||||
|
||||
```bash
|
||||
pip install pymongo
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details about MongoDB Atlas Vector Search.
|
||||
|
||||
For general MongoDB documentation, visit [MongoDB Documentation](https://docs.mongodb.com/).
|
||||
sidebar_label: Remote - Mongodb
|
||||
title: remote::mongodb
|
||||
---
|
||||
|
||||
# remote::mongodb
|
||||
|
||||
## Description
|
||||
|
||||
|
||||
[MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It
|
||||
uses MongoDB Atlas Vector Search to store and query vectors in the cloud.
|
||||
That means you get enterprise-grade vector search with MongoDB's scalability and reliability.
|
||||
|
||||
## Features
|
||||
|
||||
- Cloud-native vector search with MongoDB Atlas
|
||||
- Fully integrated with Llama Stack
|
||||
- Enterprise-grade security and scalability
|
||||
- Supports multiple search modes: vector, keyword, and hybrid search
|
||||
- Built-in metadata filtering and text search capabilities
|
||||
- Automatic index management
|
||||
|
||||
## Search Modes
|
||||
|
||||
MongoDB Atlas Vector Search supports three different search modes:
|
||||
|
||||
### Vector Search
|
||||
Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search using embedding vectors.
|
||||
|
||||
```python
|
||||
# Vector search example
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="What is machine learning?",
|
||||
search_mode="vector",
|
||||
max_num_results=5,
|
||||
)
|
||||
```
|
||||
|
||||
### Keyword Search
|
||||
Keyword search uses MongoDB's text search capabilities with full-text indexes to find chunks containing specific terms.
|
||||
|
||||
```python
|
||||
# Keyword search example
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="Python programming language",
|
||||
search_mode="keyword",
|
||||
max_num_results=5,
|
||||
)
|
||||
```
|
||||
|
||||
### Hybrid Search
|
||||
Hybrid search combines both vector and keyword search methods using configurable reranking algorithms.
|
||||
|
||||
```python
|
||||
# Hybrid search with RRF ranker (default)
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="neural networks in Python",
|
||||
search_mode="hybrid",
|
||||
max_num_results=5,
|
||||
)
|
||||
|
||||
# Hybrid search with weighted ranker
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="neural networks in Python",
|
||||
search_mode="hybrid",
|
||||
max_num_results=5,
|
||||
ranking_options={
|
||||
"ranker": {
|
||||
"type": "weighted",
|
||||
"alpha": 0.7, # 70% vector search, 30% keyword search
|
||||
}
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use MongoDB Atlas in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Create a MongoDB Atlas cluster with Vector Search enabled
|
||||
2. Install the necessary dependencies
|
||||
3. Configure your Llama Stack project to use MongoDB
|
||||
4. Start storing and querying vectors
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
Set up the following environment variable for your MongoDB Atlas connection:
|
||||
|
||||
```bash
|
||||
export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack"
|
||||
```
|
||||
|
||||
### Configuration Example
|
||||
|
||||
```yaml
|
||||
vector_io:
|
||||
- provider_id: mongodb_atlas
|
||||
provider_type: remote::mongodb
|
||||
config:
|
||||
connection_string: "${env.MONGODB_CONNECTION_STRING}"
|
||||
database_name: "llama_stack"
|
||||
index_name: "vector_index"
|
||||
similarity_metric: "cosine"
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
You can install the MongoDB Python driver using pip:
|
||||
|
||||
```bash
|
||||
pip install pymongo
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details about MongoDB Atlas Vector Search.
|
||||
|
||||
For general MongoDB documentation, visit [MongoDB Documentation](https://docs.mongodb.com/).
|
||||
|
||||
|
||||
## Configuration
|
||||
|
||||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `connection_string` | `str \| None` | No | | MongoDB connection string (e.g., mongodb://user:pass@localhost:27017/ or mongodb+srv://user:pass@cluster.mongodb.net/) |
|
||||
| `host` | `str \| None` | No | | MongoDB host (used if connection_string is not provided) |
|
||||
| `port` | `int \| None` | No | | MongoDB port (used if connection_string is not provided) |
|
||||
| `username` | `str \| None` | No | | MongoDB username (used if connection_string is not provided) |
|
||||
| `password` | `str \| None` | No | | MongoDB password (used if connection_string is not provided) |
|
||||
| `database_name` | `<class 'str'>` | No | llama_stack | Database name to use for vector collections |
|
||||
| `index_name` | `<class 'str'>` | No | vector_index | Name of the vector search index |
|
||||
| `path_field` | `<class 'str'>` | No | embedding | Field name for storing embeddings |
|
||||
| `similarity_metric` | `<class 'str'>` | No | cosine | Similarity metric: cosine, euclidean, or dotProduct |
|
||||
| `max_pool_size` | `<class 'int'>` | No | 100 | Maximum connection pool size |
|
||||
| `timeout_ms` | `<class 'int'>` | No | 30000 | Connection timeout in milliseconds |
|
||||
| `persistence` | `llama_stack.core.storage.datatypes.KVStoreReference \| None` | No | | Config for KV store backend for metadata storage |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
connection_string: ${env.MONGODB_CONNECTION_STRING:=}
|
||||
host: ${env.MONGODB_HOST:=localhost}
|
||||
port: ${env.MONGODB_PORT:=27017}
|
||||
username: ${env.MONGODB_USERNAME:=}
|
||||
password: ${env.MONGODB_PASSWORD:=}
|
||||
database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack}
|
||||
index_name: ${env.MONGODB_INDEX_NAME:=vector_index}
|
||||
path_field: ${env.MONGODB_PATH_FIELD:=embedding}
|
||||
similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine}
|
||||
max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100}
|
||||
timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000}
|
||||
persistence:
|
||||
namespace: vector_io::mongodb_atlas
|
||||
backend: kv_default
|
||||
```
|
||||
20
llama_stack/providers/remote/vector_io/mongodb/__init__.py
Normal file
20
llama_stack/providers/remote/vector_io/mongodb/__init__.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MongoDBVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .mongodb import MongoDBVectorIOAdapter
|
||||
|
||||
# Handle the deps resolution - if files API exists, pass it, otherwise None
|
||||
files_api = deps.get(Api.files)
|
||||
models_api = deps.get(Api.models)
|
||||
impl = MongoDBVectorIOAdapter(config, deps[Api.inference], files_api, models_api)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
102
llama_stack/providers/remote/vector_io/mongodb/config.py
Normal file
102
llama_stack/providers/remote/vector_io/mongodb/config.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MongoDBVectorIOConfig(BaseModel):
|
||||
"""Configuration for MongoDB Atlas Vector Search provider.
|
||||
|
||||
This provider connects to MongoDB Atlas and uses Vector Search for RAG operations.
|
||||
"""
|
||||
|
||||
# MongoDB connection details - either connection_string or individual parameters
|
||||
connection_string: str | None = Field(
|
||||
default=None,
|
||||
description="MongoDB connection string (e.g., mongodb://user:pass@localhost:27017/ or mongodb+srv://user:pass@cluster.mongodb.net/)",
|
||||
)
|
||||
host: str | None = Field(default=None, description="MongoDB host (used if connection_string is not provided)")
|
||||
port: int | None = Field(default=None, description="MongoDB port (used if connection_string is not provided)")
|
||||
username: str | None = Field(
|
||||
default=None, description="MongoDB username (used if connection_string is not provided)"
|
||||
)
|
||||
password: str | None = Field(
|
||||
default=None, description="MongoDB password (used if connection_string is not provided)"
|
||||
)
|
||||
database_name: str = Field(default="llama_stack", description="Database name to use for vector collections")
|
||||
|
||||
# Vector search configuration
|
||||
index_name: str = Field(default="vector_index", description="Name of the vector search index")
|
||||
path_field: str = Field(default="embedding", description="Field name for storing embeddings")
|
||||
similarity_metric: str = Field(
|
||||
default="cosine",
|
||||
description="Similarity metric: cosine, euclidean, or dotProduct",
|
||||
)
|
||||
|
||||
# Connection options
|
||||
max_pool_size: int = Field(default=100, description="Maximum connection pool size")
|
||||
timeout_ms: int = Field(default=30000, description="Connection timeout in milliseconds")
|
||||
|
||||
# KV store configuration
|
||||
persistence: KVStoreReference | None = Field(
|
||||
description="Config for KV store backend for metadata storage", default=None
|
||||
)
|
||||
|
||||
def get_connection_string(self) -> str | None:
|
||||
"""Build connection string from individual parameters if not provided directly.
|
||||
|
||||
If both connection_string and individual parameters (host/port) are provided,
|
||||
individual parameters take precedence to allow test environment overrides.
|
||||
"""
|
||||
# Prioritize individual connection parameters over connection_string
|
||||
# This allows test environments to override with MONGODB_HOST/PORT/etc
|
||||
if self.host and self.port:
|
||||
auth_part = ""
|
||||
if self.username and self.password:
|
||||
auth_part = f"{self.username}:{self.password}@"
|
||||
return f"mongodb://{auth_part}{self.host}:{self.port}/"
|
||||
|
||||
# Fall back to connection_string if provided
|
||||
if self.connection_string:
|
||||
return self.connection_string
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
__distro_dir__: str,
|
||||
connection_string: str = "${env.MONGODB_CONNECTION_STRING:=}",
|
||||
host: str = "${env.MONGODB_HOST:=localhost}",
|
||||
port: int = "${env.MONGODB_PORT:=27017}",
|
||||
username: str = "${env.MONGODB_USERNAME:=}",
|
||||
password: str = "${env.MONGODB_PASSWORD:=}",
|
||||
database_name: str = "${env.MONGODB_DATABASE_NAME:=llama_stack}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"connection_string": connection_string,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"database_name": database_name,
|
||||
"index_name": "${env.MONGODB_INDEX_NAME:=vector_index}",
|
||||
"path_field": "${env.MONGODB_PATH_FIELD:=embedding}",
|
||||
"similarity_metric": "${env.MONGODB_SIMILARITY_METRIC:=cosine}",
|
||||
"max_pool_size": "${env.MONGODB_MAX_POOL_SIZE:=100}",
|
||||
"timeout_ms": "${env.MONGODB_TIMEOUT_MS:=30000}",
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::mongodb_atlas",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
609
llama_stack/providers/remote/vector_io/mongodb/mongodb.py
Normal file
609
llama_stack/providers/remote/vector_io/mongodb/mongodb.py
Normal file
|
|
@ -0,0 +1,609 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import heapq
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymongo import MongoClient
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.server_api import ServerApi
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
VectorStoresProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
|
||||
OpenAIVectorStoreMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorStoreWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import (
|
||||
WeightedInMemoryAggregator,
|
||||
sanitize_collection_name,
|
||||
)
|
||||
|
||||
from .config import MongoDBVectorIOConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="vector_io::mongodb")
|
||||
|
||||
VERSION = "v1"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:mongodb:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:mongodb:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:mongodb:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:mongodb:{VERSION}::"
|
||||
|
||||
|
||||
class MongoDBIndex(EmbeddingIndex):
|
||||
"""MongoDB Atlas Vector Search index implementation optimized for RAG."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store: VectorStore,
|
||||
collection: Collection,
|
||||
config: MongoDBVectorIOConfig,
|
||||
):
|
||||
self.vector_store = vector_store
|
||||
self.collection = collection
|
||||
self.config = config
|
||||
self.dimension = vector_store.embedding_dimension
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the MongoDB collection and ensure vector search index exists."""
|
||||
try:
|
||||
# Create the collection if it doesn't exist
|
||||
collection_names = self.collection.database.list_collection_names()
|
||||
if self.collection.name not in collection_names:
|
||||
logger.info(f"Creating collection '{self.collection.name}'")
|
||||
# Create collection by inserting a dummy document
|
||||
dummy_doc = {"_id": "__dummy__", "dummy": True}
|
||||
self.collection.insert_one(dummy_doc)
|
||||
# Remove the dummy document
|
||||
self.collection.delete_one({"_id": "__dummy__"})
|
||||
logger.info(f"Collection '{self.collection.name}' created successfully")
|
||||
|
||||
# Create optimized vector search index for RAG
|
||||
await self._create_vector_search_index()
|
||||
|
||||
# Create text index for hybrid search
|
||||
await self._ensure_text_index()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to initialize MongoDB index for vector_store: {self.vector_store.identifier}. "
|
||||
f"Collection name: {self.collection.name}. Error: {str(e)}"
|
||||
)
|
||||
# Don't fail completely - just log the error and continue
|
||||
logger.warning(
|
||||
"Continuing without complete index initialization. "
|
||||
"You may need to create indexes manually in MongoDB Atlas dashboard."
|
||||
)
|
||||
|
||||
async def _create_vector_search_index(self) -> None:
|
||||
"""Create optimized vector search index based on MongoDB RAG best practices."""
|
||||
try:
|
||||
# Check if vector search index exists
|
||||
indexes = list(self.collection.list_search_indexes())
|
||||
index_exists = any(idx.get("name") == self.config.index_name for idx in indexes)
|
||||
|
||||
if not index_exists:
|
||||
# Create vector search index optimized for RAG
|
||||
# Based on MongoDB's RAG example using new vectorSearch format
|
||||
search_index_model = SearchIndexModel(
|
||||
definition={
|
||||
"fields": [
|
||||
{
|
||||
"type": "vector",
|
||||
"numDimensions": self.dimension,
|
||||
"path": self.config.path_field,
|
||||
"similarity": self._convert_similarity_metric(self.config.similarity_metric),
|
||||
}
|
||||
]
|
||||
},
|
||||
name=self.config.index_name,
|
||||
type="vectorSearch",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Creating vector search index '{self.config.index_name}' for RAG on collection '{self.collection.name}'"
|
||||
)
|
||||
|
||||
self.collection.create_search_index(model=search_index_model)
|
||||
|
||||
# Wait for index to be ready (like in MongoDB RAG example)
|
||||
await self._wait_for_index_ready()
|
||||
|
||||
logger.info("Vector search index created and ready for RAG queries")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create vector search index: {e}")
|
||||
|
||||
def _convert_similarity_metric(self, metric: str) -> str:
|
||||
"""Convert internal similarity metric to MongoDB Atlas format."""
|
||||
metric_map = {
|
||||
"cosine": "cosine",
|
||||
"euclidean": "euclidean",
|
||||
"dotProduct": "dotProduct",
|
||||
"dot_product": "dotProduct",
|
||||
}
|
||||
return metric_map.get(metric, "cosine")
|
||||
|
||||
async def _wait_for_index_ready(self) -> None:
|
||||
"""Wait for the vector search index to be ready, based on MongoDB RAG example."""
|
||||
logger.info("Waiting for vector search index to be ready...")
|
||||
|
||||
max_wait_time = 300 # 5 minutes max wait
|
||||
wait_interval = 5
|
||||
elapsed_time = 0
|
||||
|
||||
while elapsed_time < max_wait_time:
|
||||
try:
|
||||
indices = list(self.collection.list_search_indexes(self.config.index_name))
|
||||
if len(indices) and indices[0].get("queryable") is True:
|
||||
logger.info(f"Vector search index '{self.config.index_name}' is ready for querying")
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time.sleep(wait_interval)
|
||||
elapsed_time += wait_interval
|
||||
|
||||
logger.warning(f"Vector search index may not be fully ready after {max_wait_time}s")
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray) -> None:
|
||||
"""Add chunks with embeddings to MongoDB collection optimized for RAG."""
|
||||
if len(chunks) != len(embeddings):
|
||||
raise ValueError(f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}")
|
||||
|
||||
documents = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Structure document for optimal RAG retrieval
|
||||
doc = {
|
||||
"_id": chunk.chunk_id,
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"text": interleaved_content_as_str(chunk.content), # Key field for RAG context
|
||||
"content": interleaved_content_as_str(chunk.content), # Backward compatibility
|
||||
"metadata": chunk.metadata or {},
|
||||
"chunk_metadata": (chunk.chunk_metadata.model_dump() if chunk.chunk_metadata else {}),
|
||||
self.config.path_field: embeddings[i].tolist(), # Vector embeddings
|
||||
"document": chunk.model_dump(), # Full chunk data
|
||||
}
|
||||
documents.append(doc)
|
||||
|
||||
try:
|
||||
# Use upsert behavior for chunks
|
||||
for doc in documents:
|
||||
self.collection.replace_one({"_id": doc["_id"]}, doc, upsert=True)
|
||||
|
||||
logger.debug(f"Successfully added {len(chunks)} chunks optimized for RAG to MongoDB collection")
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to add chunks to MongoDB collection: {e}")
|
||||
raise
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
"""Perform vector similarity search optimized for RAG based on MongoDB example."""
|
||||
try:
|
||||
# Use MongoDB's vector search aggregation pipeline optimized for RAG
|
||||
pipeline = [
|
||||
{
|
||||
"$vectorSearch": {
|
||||
"index": self.config.index_name,
|
||||
"queryVector": embedding.tolist(),
|
||||
"path": self.config.path_field,
|
||||
"numCandidates": min(k * 10, 1000), # Cap at 1000 to prevent excessive candidates
|
||||
"limit": k,
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"_id": 0,
|
||||
"text": 1, # Primary field for RAG context
|
||||
"content": 1, # Backward compatibility
|
||||
"metadata": 1,
|
||||
"chunk_metadata": 1,
|
||||
"document": 1,
|
||||
"score": {"$meta": "vectorSearchScore"},
|
||||
}
|
||||
},
|
||||
{"$match": {"score": {"$gte": score_threshold}}},
|
||||
]
|
||||
|
||||
results = list(self.collection.aggregate(pipeline))
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for result in results:
|
||||
score = result.get("score", 0.0)
|
||||
if score >= score_threshold:
|
||||
chunk_data = result.get("document", {})
|
||||
if chunk_data:
|
||||
chunks.append(Chunk(**chunk_data))
|
||||
scores.append(float(score))
|
||||
|
||||
logger.debug(f"Vector search for RAG returned {len(chunks)} results")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Vector search for RAG failed: {e}")
|
||||
raise RuntimeError(f"Vector search for RAG failed: {e}") from e
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
"""Perform text search using MongoDB's text search for RAG context retrieval."""
|
||||
try:
|
||||
# Ensure text index exists
|
||||
await self._ensure_text_index()
|
||||
|
||||
pipeline: list[dict[str, Any]] = [
|
||||
{"$match": {"$text": {"$search": query_string}}},
|
||||
{
|
||||
"$project": {
|
||||
"_id": 0,
|
||||
"text": 1, # Primary field for RAG context
|
||||
"content": 1, # Backward compatibility
|
||||
"metadata": 1,
|
||||
"chunk_metadata": 1,
|
||||
"document": 1,
|
||||
"score": {"$meta": "textScore"},
|
||||
}
|
||||
},
|
||||
{"$match": {"score": {"$gte": score_threshold}}},
|
||||
{"$sort": {"score": {"$meta": "textScore"}}},
|
||||
{"$limit": k},
|
||||
]
|
||||
|
||||
results = list(self.collection.aggregate(pipeline))
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for result in results:
|
||||
score = result.get("score", 0.0)
|
||||
if score >= score_threshold:
|
||||
chunk_data = result.get("document", {})
|
||||
if chunk_data:
|
||||
chunks.append(Chunk(**chunk_data))
|
||||
scores.append(float(score))
|
||||
|
||||
logger.debug(f"Keyword search for RAG returned {len(chunks)} results")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Keyword search for RAG failed: {e}")
|
||||
raise RuntimeError(f"Keyword search for RAG failed: {e}") from e
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
"""Perform hybrid search for enhanced RAG context retrieval."""
|
||||
if reranker_params is None:
|
||||
reranker_params = {}
|
||||
|
||||
# Get results from both search methods
|
||||
vector_response = await self.query_vector(embedding, k, 0.0)
|
||||
keyword_response = await self.query_keyword(query_string, k, 0.0)
|
||||
|
||||
# Convert responses to score dictionaries
|
||||
vector_scores = {
|
||||
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
|
||||
}
|
||||
keyword_scores = {
|
||||
chunk.chunk_id: score
|
||||
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
|
||||
}
|
||||
|
||||
# Combine scores using the reranking utility
|
||||
combined_scores = WeightedInMemoryAggregator.combine_search_results(
|
||||
vector_scores, keyword_scores, reranker_type, reranker_params
|
||||
)
|
||||
|
||||
# Get top-k results
|
||||
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
|
||||
|
||||
# Filter by score threshold
|
||||
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
|
||||
|
||||
# Create chunk map
|
||||
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
|
||||
|
||||
# Build final results
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc_id, score in filtered_items:
|
||||
if doc_id in chunk_map:
|
||||
chunks.append(chunk_map[doc_id])
|
||||
scores.append(score)
|
||||
|
||||
logger.debug(f"Hybrid search for RAG returned {len(chunks)} results")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from MongoDB collection."""
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
try:
|
||||
result = self.collection.delete_many({"_id": {"$in": chunk_ids}})
|
||||
logger.debug(f"Deleted {result.deleted_count} chunks from MongoDB collection")
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to delete chunks: {e}")
|
||||
raise
|
||||
|
||||
async def delete(self) -> None:
|
||||
"""Delete the entire collection."""
|
||||
try:
|
||||
self.collection.drop()
|
||||
logger.debug(f"Dropped MongoDB collection: {self.collection.name}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to drop collection: {e}")
|
||||
raise
|
||||
|
||||
async def _ensure_text_index(self) -> None:
|
||||
"""Ensure text search index exists on content fields for RAG."""
|
||||
try:
|
||||
indexes = list(self.collection.list_indexes())
|
||||
text_index_exists = any(
|
||||
any(key.startswith(("content", "text")) for key in idx.get("key", {}).keys())
|
||||
and idx.get("textIndexVersion") is not None
|
||||
for idx in indexes
|
||||
)
|
||||
|
||||
if not text_index_exists:
|
||||
logger.info("Creating text search index on content fields for RAG")
|
||||
# Index both 'text' and 'content' fields for comprehensive text search
|
||||
self.collection.create_index([("text", "text"), ("content", "text")])
|
||||
logger.info("Text search index created successfully for RAG")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create text index for RAG: {e}")
|
||||
|
||||
|
||||
class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
"""MongoDB Atlas Vector Search adapter for Llama Stack optimized for RAG workflows."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MongoDBVectorIOConfig,
|
||||
inference_api,
|
||||
files_api=None,
|
||||
models_api=None,
|
||||
) -> None:
|
||||
# Handle the case where files_api might be a ProviderSpec that needs resolution
|
||||
resolved_files_api = files_api
|
||||
super().__init__(files_api=resolved_files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.client: MongoClient | None = None
|
||||
self.database: Database | None = None
|
||||
self.cache: dict[str, VectorStoreWithIndex] = {}
|
||||
self.kvstore: KVStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize MongoDB connection optimized for RAG workflows."""
|
||||
logger.info("Initializing MongoDB Atlas Vector IO adapter for RAG")
|
||||
|
||||
try:
|
||||
# Initialize KV store for metadata
|
||||
if self.config.persistence:
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
|
||||
# Skip MongoDB connection if no connection string provided
|
||||
# This allows other providers to work without MongoDB credentials
|
||||
if not self.config.connection_string:
|
||||
logger.warning(
|
||||
"MongoDB connection_string not provided. "
|
||||
"MongoDB vector store will not be available until credentials are configured."
|
||||
)
|
||||
return
|
||||
|
||||
# Connect to MongoDB with optimized settings for RAG
|
||||
self.client = MongoClient(
|
||||
self.config.connection_string,
|
||||
server_api=ServerApi("1"),
|
||||
maxPoolSize=self.config.max_pool_size,
|
||||
serverSelectionTimeoutMS=self.config.timeout_ms,
|
||||
# Additional settings for RAG performance
|
||||
retryWrites=True,
|
||||
readPreference="primaryPreferred",
|
||||
)
|
||||
|
||||
# Test connection
|
||||
self.client.admin.command("ping")
|
||||
logger.info("Successfully connected to MongoDB Atlas for RAG")
|
||||
|
||||
# Get database
|
||||
self.database = self.client[self.config.database_name]
|
||||
|
||||
# Initialize OpenAI vector stores
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
# Load existing vector databases
|
||||
await self._load_existing_vector_dbs()
|
||||
|
||||
logger.info("MongoDB Atlas Vector IO adapter for RAG initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to initialize MongoDB Atlas Vector IO adapter for RAG")
|
||||
raise RuntimeError("Failed to initialize MongoDB Atlas Vector IO adapter for RAG") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown MongoDB connection."""
|
||||
if self.client:
|
||||
self.client.close()
|
||||
logger.info("MongoDB Atlas RAG connection closed")
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""Perform health check on MongoDB connection."""
|
||||
try:
|
||||
if self.client:
|
||||
self.client.admin.command("ping")
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
else:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message="MongoDB client not initialized")
|
||||
except Exception as e:
|
||||
return HealthResponse(
|
||||
status=HealthStatus.ERROR,
|
||||
message=f"MongoDB RAG health check failed: {str(e)}",
|
||||
)
|
||||
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
"""Register a new vector store optimized for RAG."""
|
||||
if self.database is None:
|
||||
raise RuntimeError("MongoDB database not initialized")
|
||||
|
||||
# Create collection name from vector store identifier
|
||||
collection_name = sanitize_collection_name(vector_store.identifier)
|
||||
collection = self.database[collection_name]
|
||||
|
||||
# Create and initialize MongoDB index optimized for RAG
|
||||
mongodb_index = MongoDBIndex(vector_store, collection, self.config)
|
||||
await mongodb_index.initialize()
|
||||
|
||||
# Create vector store with index wrapper
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=mongodb_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
# Cache the vector store
|
||||
self.cache[vector_store.identifier] = vector_store_with_index
|
||||
|
||||
# Save vector store info to KVStore for persistence
|
||||
if self.kvstore:
|
||||
await self.kvstore.set(
|
||||
f"{VECTOR_DBS_PREFIX}{vector_store.identifier}",
|
||||
vector_store.model_dump_json(),
|
||||
)
|
||||
|
||||
logger.info(f"Registered vector store for RAG: {vector_store.identifier}")
|
||||
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
"""Unregister a vector store."""
|
||||
if vector_store_id in self.cache:
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
# Clean up from KV store
|
||||
if self.kvstore:
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
||||
|
||||
logger.info(f"Unregistered vector store: {vector_store_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""Insert chunks into the vector database optimized for RAG."""
|
||||
vector_db_with_index = await self._get_vector_db_index(vector_db_id)
|
||||
await vector_db_with_index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
"""Query chunks from the vector database optimized for RAG context retrieval."""
|
||||
vector_db_with_index = await self._get_vector_db_index(vector_db_id)
|
||||
return await vector_db_with_index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from the vector database."""
|
||||
vector_db_with_index = await self._get_vector_db_index(store_id)
|
||||
await vector_db_with_index.index.delete_chunks(chunks_for_deletion)
|
||||
|
||||
async def _get_vector_db_index(self, vector_db_id: str) -> VectorStoreWithIndex:
|
||||
"""Get vector store index from cache."""
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
async def _load_existing_vector_dbs(self) -> None:
|
||||
"""Load existing vector databases from KVStore."""
|
||||
if not self.kvstore:
|
||||
return
|
||||
|
||||
try:
|
||||
# Use keys_in_range to get all vector database keys from KVStore
|
||||
# This searches for keys with the prefix by using range scan
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
# Create an end key by incrementing the last character
|
||||
end_key = VECTOR_DBS_PREFIX[:-1] + chr(ord(VECTOR_DBS_PREFIX[-1]) + 1)
|
||||
|
||||
vector_db_keys = await self.kvstore.keys_in_range(start_key, end_key)
|
||||
|
||||
for key in vector_db_keys:
|
||||
try:
|
||||
vector_store_data = await self.kvstore.get(key)
|
||||
if vector_store_data:
|
||||
import json
|
||||
|
||||
vector_store = VectorStore(**json.loads(vector_store_data))
|
||||
# Register the vector store without re-initializing
|
||||
await self._register_existing_vector_store(vector_store)
|
||||
logger.info(f"Loaded existing RAG-optimized vector store: {vector_store.identifier}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load vector store from key {key}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load existing vector stores: {e}")
|
||||
|
||||
async def _register_existing_vector_store(self, vector_store: VectorStore) -> None:
|
||||
"""Register an existing vector store without re-initialization."""
|
||||
if self.database is None:
|
||||
raise RuntimeError("MongoDB database not initialized")
|
||||
|
||||
# Create collection name from vector store identifier
|
||||
collection_name = sanitize_collection_name(vector_store.identifier)
|
||||
collection = self.database[collection_name]
|
||||
|
||||
# Create MongoDB index without initialization (collection already exists)
|
||||
mongodb_index = MongoDBIndex(vector_store, collection, self.config)
|
||||
|
||||
# Create vector store with index wrapper
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=mongodb_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
# Cache the vector store
|
||||
self.cache[vector_store.identifier] = vector_store_with_index
|
||||
|
|
@ -25,6 +25,7 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::mongodb
|
||||
- provider_type: remote::qdrant
|
||||
- provider_type: remote::weaviate
|
||||
files:
|
||||
|
|
|
|||
|
|
@ -131,6 +131,23 @@ providers:
|
|||
persistence:
|
||||
namespace: vector_io::pgvector
|
||||
backend: kv_default
|
||||
- provider_id: ${env.MONGODB_CONNECTION_STRING:+mongodb_atlas}
|
||||
provider_type: remote::mongodb
|
||||
config:
|
||||
connection_string: ${env.MONGODB_CONNECTION_STRING:=}
|
||||
host: ${env.MONGODB_HOST:=localhost}
|
||||
port: ${env.MONGODB_PORT:=27017}
|
||||
username: ${env.MONGODB_USERNAME:=}
|
||||
password: ${env.MONGODB_PASSWORD:=}
|
||||
database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack}
|
||||
index_name: ${env.MONGODB_INDEX_NAME:=vector_index}
|
||||
path_field: ${env.MONGODB_PATH_FIELD:=embedding}
|
||||
similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine}
|
||||
max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100}
|
||||
timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000}
|
||||
persistence:
|
||||
namespace: vector_io::mongodb_atlas
|
||||
backend: kv_default
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::mongodb
|
||||
- provider_type: remote::qdrant
|
||||
- provider_type: remote::weaviate
|
||||
files:
|
||||
|
|
|
|||
|
|
@ -131,6 +131,23 @@ providers:
|
|||
persistence:
|
||||
namespace: vector_io::pgvector
|
||||
backend: kv_default
|
||||
- provider_id: ${env.MONGODB_CONNECTION_STRING:+mongodb_atlas}
|
||||
provider_type: remote::mongodb
|
||||
config:
|
||||
connection_string: ${env.MONGODB_CONNECTION_STRING:=}
|
||||
host: ${env.MONGODB_HOST:=localhost}
|
||||
port: ${env.MONGODB_PORT:=27017}
|
||||
username: ${env.MONGODB_USERNAME:=}
|
||||
password: ${env.MONGODB_PASSWORD:=}
|
||||
database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack}
|
||||
index_name: ${env.MONGODB_INDEX_NAME:=vector_index}
|
||||
path_field: ${env.MONGODB_PATH_FIELD:=embedding}
|
||||
similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine}
|
||||
max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100}
|
||||
timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000}
|
||||
persistence:
|
||||
namespace: vector_io::mongodb_atlas
|
||||
backend: kv_default
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ distribution_spec:
|
|||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
- provider_type: remote::mongodb
|
||||
- provider_type: remote::qdrant
|
||||
- provider_type: remote::weaviate
|
||||
files:
|
||||
|
|
|
|||
|
|
@ -131,6 +131,23 @@ providers:
|
|||
persistence:
|
||||
namespace: vector_io::pgvector
|
||||
backend: kv_default
|
||||
- provider_id: ${env.MONGODB_CONNECTION_STRING:+mongodb_atlas}
|
||||
provider_type: remote::mongodb
|
||||
config:
|
||||
connection_string: ${env.MONGODB_CONNECTION_STRING:=}
|
||||
host: ${env.MONGODB_HOST:=localhost}
|
||||
port: ${env.MONGODB_PORT:=27017}
|
||||
username: ${env.MONGODB_USERNAME:=}
|
||||
password: ${env.MONGODB_PASSWORD:=}
|
||||
database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack}
|
||||
index_name: ${env.MONGODB_INDEX_NAME:=vector_index}
|
||||
path_field: ${env.MONGODB_PATH_FIELD:=embedding}
|
||||
similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine}
|
||||
max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100}
|
||||
timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000}
|
||||
persistence:
|
||||
namespace: vector_io::mongodb_atlas
|
||||
backend: kv_default
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
|
|
|
|||
|
|
@ -36,11 +36,14 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
|||
)
|
||||
from llama_stack.providers.registry.inference import available_providers
|
||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.mongodb.config import MongoDBVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import (
|
||||
WeaviateVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
|
||||
|
|
@ -124,6 +127,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
BuildProvider(provider_type="inline::milvus"),
|
||||
BuildProvider(provider_type="remote::chromadb"),
|
||||
BuildProvider(provider_type="remote::pgvector"),
|
||||
BuildProvider(provider_type="remote::mongodb"),
|
||||
BuildProvider(provider_type="remote::qdrant"),
|
||||
BuildProvider(provider_type="remote::weaviate"),
|
||||
],
|
||||
|
|
@ -254,7 +258,70 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
additional_pip_packages=list(set(PostgresSqlStoreConfig.pip_packages() + PostgresKVStoreConfig.pip_packages())),
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides=default_overrides,
|
||||
provider_overrides={
|
||||
"inference": remote_inference_providers + [embedding_provider],
|
||||
"vector_io": [
|
||||
Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.MILVUS_URL:+milvus}",
|
||||
provider_type="inline::milvus",
|
||||
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.CHROMADB_URL:+chromadb}",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}/",
|
||||
url="${env.CHROMADB_URL:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.PGVECTOR_DB:+pgvector}",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
db="${env.PGVECTOR_DB:=}",
|
||||
user="${env.PGVECTOR_USER:=}",
|
||||
password="${env.PGVECTOR_PASSWORD:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.MONGODB_CONNECTION_STRING:+mongodb_atlas}",
|
||||
provider_type="remote::mongodb",
|
||||
config=MongoDBVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
connection_string="${env.MONGODB_CONNECTION_STRING:=}",
|
||||
database_name="${env.MONGODB_DATABASE_NAME:=llama_stack}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.QDRANT_URL:+qdrant}",
|
||||
provider_type="remote::qdrant",
|
||||
config=QdrantVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
url="${env.QDRANT_URL:=}",
|
||||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
|
||||
provider_type="remote::weaviate",
|
||||
config=WeaviateVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
cluster_url="${env.WEAVIATE_CLUSTER_URL:=}",
|
||||
),
|
||||
),
|
||||
],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
default_shields=default_shields,
|
||||
|
|
@ -384,5 +451,13 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
"azure",
|
||||
"Azure API Type",
|
||||
),
|
||||
"MONGODB_CONNECTION_STRING": (
|
||||
"",
|
||||
"MongoDB Atlas connection string (e.g., mongodb+srv://user:pass@cluster.mongodb.net/)",
|
||||
),
|
||||
"MONGODB_DATABASE_NAME": (
|
||||
"llama_stack",
|
||||
"MongoDB database name",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -823,6 +823,132 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
|||
optional_api_dependencies=[Api.files, Api.models],
|
||||
description="""
|
||||
Please refer to the remote provider documentation.
|
||||
""",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="mongodb",
|
||||
provider_type="remote::mongodb",
|
||||
pip_packages=["pymongo>=4.0.0"],
|
||||
module="llama_stack.providers.remote.vector_io.mongodb",
|
||||
config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It
|
||||
uses MongoDB Atlas Vector Search to store and query vectors in the cloud.
|
||||
That means you get enterprise-grade vector search with MongoDB's scalability and reliability.
|
||||
|
||||
## Features
|
||||
|
||||
- Cloud-native vector search with MongoDB Atlas
|
||||
- Fully integrated with Llama Stack
|
||||
- Enterprise-grade security and scalability
|
||||
- Supports multiple search modes: vector, keyword, and hybrid search
|
||||
- Built-in metadata filtering and text search capabilities
|
||||
- Automatic index management
|
||||
|
||||
## Search Modes
|
||||
|
||||
MongoDB Atlas Vector Search supports three different search modes:
|
||||
|
||||
### Vector Search
|
||||
Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search using embedding vectors.
|
||||
|
||||
```python
|
||||
# Vector search example
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="What is machine learning?",
|
||||
search_mode="vector",
|
||||
max_num_results=5,
|
||||
)
|
||||
```
|
||||
|
||||
### Keyword Search
|
||||
Keyword search uses MongoDB's text search capabilities with full-text indexes to find chunks containing specific terms.
|
||||
|
||||
```python
|
||||
# Keyword search example
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="Python programming language",
|
||||
search_mode="keyword",
|
||||
max_num_results=5,
|
||||
)
|
||||
```
|
||||
|
||||
### Hybrid Search
|
||||
Hybrid search combines both vector and keyword search methods using configurable reranking algorithms.
|
||||
|
||||
```python
|
||||
# Hybrid search with RRF ranker (default)
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="neural networks in Python",
|
||||
search_mode="hybrid",
|
||||
max_num_results=5,
|
||||
)
|
||||
|
||||
# Hybrid search with weighted ranker
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="neural networks in Python",
|
||||
search_mode="hybrid",
|
||||
max_num_results=5,
|
||||
ranking_options={
|
||||
"ranker": {
|
||||
"type": "weighted",
|
||||
"alpha": 0.7, # 70% vector search, 30% keyword search
|
||||
}
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To use MongoDB Atlas in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Create a MongoDB Atlas cluster with Vector Search enabled
|
||||
2. Install the necessary dependencies
|
||||
3. Configure your Llama Stack project to use MongoDB
|
||||
4. Start storing and querying vectors
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
Set up the following environment variable for your MongoDB Atlas connection:
|
||||
|
||||
```bash
|
||||
export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack"
|
||||
```
|
||||
|
||||
### Configuration Example
|
||||
|
||||
```yaml
|
||||
vector_io:
|
||||
- provider_id: mongodb_atlas
|
||||
provider_type: remote::mongodb
|
||||
config:
|
||||
connection_string: "${env.MONGODB_CONNECTION_STRING}"
|
||||
database_name: "llama_stack"
|
||||
index_name: "vector_index"
|
||||
similarity_metric: "cosine"
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
You can install the MongoDB Python driver using pip:
|
||||
|
||||
```bash
|
||||
pip install pymongo
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details about MongoDB Atlas Vector Search.
|
||||
|
||||
For general MongoDB documentation, visit [MongoDB Documentation](https://docs.mongodb.com/).
|
||||
""",
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MongoDBVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from .mongodb import MongoDBVectorIOAdapter
|
||||
|
||||
# Handle the deps resolution - if files API exists, pass it, otherwise None
|
||||
files_api = deps.get(Api.files)
|
||||
models_api = deps.get(Api.models)
|
||||
impl = MongoDBVectorIOAdapter(config, deps[Api.inference], files_api, models_api)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
110
src/llama_stack/providers/remote/vector_io/mongodb/config.py
Normal file
110
src/llama_stack/providers/remote/vector_io/mongodb/config.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MongoDBVectorIOConfig(BaseModel):
|
||||
"""Configuration for MongoDB Atlas Vector Search provider.
|
||||
|
||||
This provider connects to MongoDB Atlas and uses Vector Search for RAG operations.
|
||||
"""
|
||||
|
||||
# MongoDB connection details - either connection_string or individual parameters
|
||||
connection_string: str | None = Field(
|
||||
default=None,
|
||||
description="MongoDB connection string (e.g., mongodb://user:pass@localhost:27017/ or mongodb+srv://user:pass@cluster.mongodb.net/)",
|
||||
)
|
||||
host: str | None = Field(
|
||||
default=None,
|
||||
description="MongoDB host (used if connection_string is not provided)",
|
||||
)
|
||||
port: int | None = Field(
|
||||
default=None,
|
||||
description="MongoDB port (used if connection_string is not provided)",
|
||||
)
|
||||
username: str | None = Field(
|
||||
default=None,
|
||||
description="MongoDB username (used if connection_string is not provided)",
|
||||
)
|
||||
password: str | None = Field(
|
||||
default=None,
|
||||
description="MongoDB password (used if connection_string is not provided)",
|
||||
)
|
||||
database_name: str = Field(default="llama_stack", description="Database name to use for vector collections")
|
||||
|
||||
# Vector search configuration
|
||||
index_name: str = Field(default="vector_index", description="Name of the vector search index")
|
||||
path_field: str = Field(default="embedding", description="Field name for storing embeddings")
|
||||
similarity_metric: str = Field(
|
||||
default="cosine",
|
||||
description="Similarity metric: cosine, euclidean, or dotProduct",
|
||||
)
|
||||
|
||||
# Connection options
|
||||
max_pool_size: int = Field(default=100, description="Maximum connection pool size")
|
||||
timeout_ms: int = Field(default=30000, description="Connection timeout in milliseconds")
|
||||
|
||||
# KV store configuration
|
||||
persistence: KVStoreReference | None = Field(
|
||||
description="Config for KV store backend for metadata storage", default=None
|
||||
)
|
||||
|
||||
def get_connection_string(self) -> str | None:
|
||||
"""Build connection string from individual parameters if not provided directly.
|
||||
|
||||
If both connection_string and individual parameters (host/port) are provided,
|
||||
individual parameters take precedence to allow test environment overrides.
|
||||
"""
|
||||
# Prioritize individual connection parameters over connection_string
|
||||
# This allows test environments to override with MONGODB_HOST/PORT/etc
|
||||
if self.host and self.port:
|
||||
auth_part = ""
|
||||
if self.username and self.password:
|
||||
auth_part = f"{self.username}:{self.password}@"
|
||||
return f"mongodb://{auth_part}{self.host}:{self.port}/"
|
||||
|
||||
# Fall back to connection_string if provided
|
||||
if self.connection_string:
|
||||
return self.connection_string
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
__distro_dir__: str,
|
||||
connection_string: str = "${env.MONGODB_CONNECTION_STRING:=}",
|
||||
host: str = "${env.MONGODB_HOST:=localhost}",
|
||||
port: str = "${env.MONGODB_PORT:=27017}",
|
||||
username: str = "${env.MONGODB_USERNAME:=}",
|
||||
password: str = "${env.MONGODB_PASSWORD:=}",
|
||||
database_name: str = "${env.MONGODB_DATABASE_NAME:=llama_stack}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"connection_string": connection_string,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"database_name": database_name,
|
||||
"index_name": "${env.MONGODB_INDEX_NAME:=vector_index}",
|
||||
"path_field": "${env.MONGODB_PATH_FIELD:=embedding}",
|
||||
"similarity_metric": "${env.MONGODB_SIMILARITY_METRIC:=cosine}",
|
||||
"max_pool_size": "${env.MONGODB_MAX_POOL_SIZE:=100}",
|
||||
"timeout_ms": "${env.MONGODB_TIMEOUT_MS:=30000}",
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::mongodb_atlas",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
631
src/llama_stack/providers/remote/vector_io/mongodb/mongodb.py
Normal file
631
src/llama_stack/providers/remote/vector_io/mongodb/mongodb.py
Normal file
|
|
@ -0,0 +1,631 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import heapq
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymongo import MongoClient
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.server_api import ServerApi
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
VectorStoresProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
|
||||
OpenAIVectorStoreMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorStoreWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import (
|
||||
WeightedInMemoryAggregator,
|
||||
sanitize_collection_name,
|
||||
)
|
||||
|
||||
from .config import MongoDBVectorIOConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="vector_io::mongodb")
|
||||
|
||||
VERSION = "v1"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:mongodb:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:mongodb:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:mongodb:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:mongodb:{VERSION}::"
|
||||
|
||||
|
||||
class MongoDBIndex(EmbeddingIndex):
|
||||
"""MongoDB Atlas Vector Search index implementation optimized for RAG."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store: VectorStore,
|
||||
collection: Collection,
|
||||
config: MongoDBVectorIOConfig,
|
||||
):
|
||||
self.vector_store = vector_store
|
||||
self.collection = collection
|
||||
self.config = config
|
||||
self.dimension = vector_store.embedding_dimension
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the MongoDB collection and ensure vector search index exists."""
|
||||
try:
|
||||
# Create the collection if it doesn't exist
|
||||
collection_names = self.collection.database.list_collection_names()
|
||||
if self.collection.name not in collection_names:
|
||||
logger.info(f"Creating collection '{self.collection.name}'")
|
||||
# Create collection by inserting a dummy document
|
||||
dummy_doc = {"_id": "__dummy__", "dummy": True}
|
||||
self.collection.insert_one(dummy_doc)
|
||||
# Remove the dummy document
|
||||
self.collection.delete_one({"_id": "__dummy__"})
|
||||
logger.info(f"Collection '{self.collection.name}' created successfully")
|
||||
|
||||
# Create optimized vector search index for RAG
|
||||
await self._create_vector_search_index()
|
||||
|
||||
# Create text index for hybrid search
|
||||
await self._ensure_text_index()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to initialize MongoDB index for vector_store: {self.vector_store.identifier}. "
|
||||
f"Collection name: {self.collection.name}. Error: {str(e)}"
|
||||
)
|
||||
# Don't fail completely - just log the error and continue
|
||||
logger.warning(
|
||||
"Continuing without complete index initialization. "
|
||||
"You may need to create indexes manually in MongoDB Atlas dashboard."
|
||||
)
|
||||
|
||||
async def _create_vector_search_index(self) -> None:
|
||||
"""Create optimized vector search index based on MongoDB RAG best practices."""
|
||||
try:
|
||||
# Check if vector search index exists
|
||||
indexes = list(self.collection.list_search_indexes())
|
||||
index_exists = any(idx.get("name") == self.config.index_name for idx in indexes)
|
||||
|
||||
if not index_exists:
|
||||
# Create vector search index optimized for RAG
|
||||
# Based on MongoDB's RAG example using new vectorSearch format
|
||||
search_index_model = SearchIndexModel(
|
||||
definition={
|
||||
"fields": [
|
||||
{
|
||||
"type": "vector",
|
||||
"numDimensions": self.dimension,
|
||||
"path": self.config.path_field,
|
||||
"similarity": self._convert_similarity_metric(self.config.similarity_metric),
|
||||
}
|
||||
]
|
||||
},
|
||||
name=self.config.index_name,
|
||||
type="vectorSearch",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Creating vector search index '{self.config.index_name}' for RAG on collection '{self.collection.name}'"
|
||||
)
|
||||
|
||||
self.collection.create_search_index(model=search_index_model)
|
||||
|
||||
# Wait for index to be ready (like in MongoDB RAG example)
|
||||
await self._wait_for_index_ready()
|
||||
|
||||
logger.info("Vector search index created and ready for RAG queries")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create vector search index: {e}")
|
||||
|
||||
def _convert_similarity_metric(self, metric: str) -> str:
|
||||
"""Convert internal similarity metric to MongoDB Atlas format."""
|
||||
metric_map = {
|
||||
"cosine": "cosine",
|
||||
"euclidean": "euclidean",
|
||||
"dotProduct": "dotProduct",
|
||||
"dot_product": "dotProduct",
|
||||
}
|
||||
return metric_map.get(metric, "cosine")
|
||||
|
||||
async def _wait_for_index_ready(self) -> None:
|
||||
"""Wait for the vector search index to be ready, based on MongoDB RAG example."""
|
||||
logger.info("Waiting for vector search index to be ready...")
|
||||
|
||||
max_wait_time = 300 # 5 minutes max wait
|
||||
wait_interval = 5
|
||||
elapsed_time = 0
|
||||
|
||||
while elapsed_time < max_wait_time:
|
||||
try:
|
||||
indices = list(self.collection.list_search_indexes(self.config.index_name))
|
||||
if len(indices) and indices[0].get("queryable") is True:
|
||||
logger.info(f"Vector search index '{self.config.index_name}' is ready for querying")
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time.sleep(wait_interval)
|
||||
elapsed_time += wait_interval
|
||||
|
||||
logger.warning(f"Vector search index may not be fully ready after {max_wait_time}s")
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray) -> None:
|
||||
"""Add chunks with embeddings to MongoDB collection optimized for RAG."""
|
||||
if len(chunks) != len(embeddings):
|
||||
raise ValueError(f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}")
|
||||
|
||||
documents = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Structure document for optimal RAG retrieval
|
||||
doc = {
|
||||
"_id": chunk.chunk_id,
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"text": interleaved_content_as_str(chunk.content), # Key field for RAG context
|
||||
"content": interleaved_content_as_str(chunk.content), # Backward compatibility
|
||||
"metadata": chunk.metadata or {},
|
||||
"chunk_metadata": (chunk.chunk_metadata.model_dump() if chunk.chunk_metadata else {}),
|
||||
self.config.path_field: embeddings[i].tolist(), # Vector embeddings
|
||||
"document": chunk.model_dump(), # Full chunk data
|
||||
}
|
||||
documents.append(doc)
|
||||
|
||||
try:
|
||||
# Use upsert behavior for chunks
|
||||
for doc in documents:
|
||||
self.collection.replace_one({"_id": doc["_id"]}, doc, upsert=True)
|
||||
|
||||
logger.debug(f"Successfully added {len(chunks)} chunks optimized for RAG to MongoDB collection")
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to add chunks to MongoDB collection: {e}")
|
||||
raise
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
"""Perform vector similarity search optimized for RAG based on MongoDB example."""
|
||||
try:
|
||||
# Use MongoDB's vector search aggregation pipeline optimized for RAG
|
||||
pipeline = [
|
||||
{
|
||||
"$vectorSearch": {
|
||||
"index": self.config.index_name,
|
||||
"queryVector": embedding.tolist(),
|
||||
"path": self.config.path_field,
|
||||
"numCandidates": min(k * 10, 1000), # Cap at 1000 to prevent excessive candidates
|
||||
"limit": k,
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"_id": 0,
|
||||
"text": 1, # Primary field for RAG context
|
||||
"content": 1, # Backward compatibility
|
||||
"metadata": 1,
|
||||
"chunk_metadata": 1,
|
||||
"document": 1,
|
||||
"score": {"$meta": "vectorSearchScore"},
|
||||
}
|
||||
},
|
||||
{"$match": {"score": {"$gte": score_threshold}}},
|
||||
]
|
||||
|
||||
results = list(self.collection.aggregate(pipeline))
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for result in results:
|
||||
score = result.get("score", 0.0)
|
||||
if score >= score_threshold:
|
||||
chunk_data = result.get("document", {})
|
||||
if chunk_data:
|
||||
chunks.append(Chunk(**chunk_data))
|
||||
scores.append(float(score))
|
||||
|
||||
logger.debug(f"Vector search for RAG returned {len(chunks)} results")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Vector search for RAG failed: {e}")
|
||||
raise RuntimeError(f"Vector search for RAG failed: {e}") from e
|
||||
|
||||
async def query_keyword(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
"""Perform text search using MongoDB's text search for RAG context retrieval."""
|
||||
try:
|
||||
# Ensure text index exists
|
||||
await self._ensure_text_index()
|
||||
|
||||
pipeline: list[dict[str, Any]] = [
|
||||
{"$match": {"$text": {"$search": query_string}}},
|
||||
{
|
||||
"$project": {
|
||||
"_id": 0,
|
||||
"text": 1, # Primary field for RAG context
|
||||
"content": 1, # Backward compatibility
|
||||
"metadata": 1,
|
||||
"chunk_metadata": 1,
|
||||
"document": 1,
|
||||
"score": {"$meta": "textScore"},
|
||||
}
|
||||
},
|
||||
{"$match": {"score": {"$gte": score_threshold}}},
|
||||
{"$sort": {"score": {"$meta": "textScore"}}},
|
||||
{"$limit": k},
|
||||
]
|
||||
|
||||
results = list(self.collection.aggregate(pipeline))
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for result in results:
|
||||
score = result.get("score", 0.0)
|
||||
if score >= score_threshold:
|
||||
chunk_data = result.get("document", {})
|
||||
if chunk_data:
|
||||
chunks.append(Chunk(**chunk_data))
|
||||
scores.append(float(score))
|
||||
|
||||
logger.debug(f"Keyword search for RAG returned {len(chunks)} results")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Keyword search for RAG failed: {e}")
|
||||
raise RuntimeError(f"Keyword search for RAG failed: {e}") from e
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
"""Perform hybrid search for enhanced RAG context retrieval."""
|
||||
if reranker_params is None:
|
||||
reranker_params = {}
|
||||
|
||||
# Get results from both search methods
|
||||
vector_response = await self.query_vector(embedding, k, 0.0)
|
||||
keyword_response = await self.query_keyword(query_string, k, 0.0)
|
||||
|
||||
# Convert responses to score dictionaries
|
||||
vector_scores = {
|
||||
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
|
||||
}
|
||||
keyword_scores = {
|
||||
chunk.chunk_id: score
|
||||
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
|
||||
}
|
||||
|
||||
# Combine scores using the reranking utility
|
||||
combined_scores = WeightedInMemoryAggregator.combine_search_results(
|
||||
vector_scores, keyword_scores, reranker_type, reranker_params
|
||||
)
|
||||
|
||||
# Get top-k results
|
||||
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
|
||||
|
||||
# Filter by score threshold
|
||||
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
|
||||
|
||||
# Create chunk map
|
||||
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
|
||||
|
||||
# Build final results
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc_id, score in filtered_items:
|
||||
if doc_id in chunk_map:
|
||||
chunks.append(chunk_map[doc_id])
|
||||
scores.append(score)
|
||||
|
||||
logger.debug(f"Hybrid search for RAG returned {len(chunks)} results")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from MongoDB collection."""
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
try:
|
||||
result = self.collection.delete_many({"_id": {"$in": chunk_ids}})
|
||||
logger.debug(f"Deleted {result.deleted_count} chunks from MongoDB collection")
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to delete chunks: {e}")
|
||||
raise
|
||||
|
||||
async def delete(self) -> None:
|
||||
"""Delete the entire collection."""
|
||||
try:
|
||||
self.collection.drop()
|
||||
logger.debug(f"Dropped MongoDB collection: {self.collection.name}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to drop collection: {e}")
|
||||
raise
|
||||
|
||||
async def _ensure_text_index(self) -> None:
|
||||
"""Ensure text search index exists on content fields for RAG."""
|
||||
try:
|
||||
indexes = list(self.collection.list_indexes())
|
||||
text_index_exists = any(
|
||||
any(key.startswith(("content", "text")) for key in idx.get("key", {}).keys())
|
||||
and idx.get("textIndexVersion") is not None
|
||||
for idx in indexes
|
||||
)
|
||||
|
||||
if not text_index_exists:
|
||||
logger.info("Creating text search index on content fields for RAG")
|
||||
# Index both 'text' and 'content' fields for comprehensive text search
|
||||
self.collection.create_index([("text", "text"), ("content", "text")])
|
||||
logger.info("Text search index created successfully for RAG")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create text index for RAG: {e}")
|
||||
|
||||
|
||||
class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
"""MongoDB Atlas Vector Search adapter for Llama Stack optimized for RAG workflows."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MongoDBVectorIOConfig,
|
||||
inference_api,
|
||||
files_api=None,
|
||||
models_api=None,
|
||||
) -> None:
|
||||
# Handle the case where files_api might be a ProviderSpec that needs resolution
|
||||
resolved_files_api = files_api
|
||||
super().__init__(files_api=resolved_files_api, kvstore=None)
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.models_api = models_api
|
||||
self.client: MongoClient | None = None
|
||||
self.database: Database | None = None
|
||||
self.cache: dict[str, VectorStoreWithIndex] = {}
|
||||
self.kvstore: KVStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize MongoDB connection optimized for RAG workflows."""
|
||||
logger.info("Initializing MongoDB Atlas Vector IO adapter for RAG")
|
||||
|
||||
try:
|
||||
# Initialize KV store for metadata
|
||||
if self.config.persistence:
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
|
||||
# Get connection string from config (either direct or built from parameters)
|
||||
connection_string = self.config.get_connection_string()
|
||||
|
||||
# Skip MongoDB connection if no connection parameters provided
|
||||
# This allows other providers to work without MongoDB credentials
|
||||
if not connection_string:
|
||||
logger.warning(
|
||||
"MongoDB connection parameters not provided. "
|
||||
"MongoDB vector store will not be available until credentials are configured."
|
||||
)
|
||||
return
|
||||
|
||||
# Connect to MongoDB with optimized settings for RAG
|
||||
self.client = MongoClient(
|
||||
connection_string,
|
||||
server_api=ServerApi("1"),
|
||||
maxPoolSize=self.config.max_pool_size,
|
||||
serverSelectionTimeoutMS=self.config.timeout_ms,
|
||||
# Additional settings for RAG performance
|
||||
retryWrites=True,
|
||||
readPreference="primaryPreferred",
|
||||
)
|
||||
|
||||
# Test connection
|
||||
try:
|
||||
self.client.admin.command("ping")
|
||||
logger.info("Successfully connected to MongoDB Atlas for RAG")
|
||||
except Exception as conn_error:
|
||||
# Extract just the basic error type without the full traceback
|
||||
error_type = type(conn_error).__name__
|
||||
logger.warning(
|
||||
f"MongoDB connection failed ({error_type}). "
|
||||
"MongoDB vector store will not be available. "
|
||||
f"Attempted to connect to: {self.config.host or 'connection_string'}:{self.config.port or '(from connection_string)'}"
|
||||
)
|
||||
# Close the client and clear it
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self.client = None
|
||||
return
|
||||
|
||||
# Get database
|
||||
self.database = self.client[self.config.database_name]
|
||||
|
||||
# Initialize OpenAI vector stores
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
# Load existing vector databases
|
||||
await self._load_existing_vector_dbs()
|
||||
|
||||
logger.info("MongoDB Atlas Vector IO adapter for RAG initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to initialize MongoDB Atlas Vector IO adapter for RAG")
|
||||
# Close the client if it was created
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self.client = None
|
||||
# Log warning instead of raising to allow tests to skip gracefully
|
||||
logger.warning(f"MongoDB initialization failed: {e}. MongoDB vector store will not be available.")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown MongoDB connection."""
|
||||
if self.client:
|
||||
self.client.close()
|
||||
logger.info("MongoDB Atlas RAG connection closed")
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""Perform health check on MongoDB connection."""
|
||||
try:
|
||||
if self.client:
|
||||
self.client.admin.command("ping")
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
else:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message="MongoDB client not initialized")
|
||||
except Exception as e:
|
||||
return HealthResponse(
|
||||
status=HealthStatus.ERROR,
|
||||
message=f"MongoDB RAG health check failed: {str(e)}",
|
||||
)
|
||||
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
"""Register a new vector store optimized for RAG."""
|
||||
if self.database is None:
|
||||
raise RuntimeError("MongoDB database not initialized")
|
||||
|
||||
# Create collection name from vector store identifier
|
||||
collection_name = sanitize_collection_name(vector_store.identifier)
|
||||
collection = self.database[collection_name]
|
||||
|
||||
# Create and initialize MongoDB index optimized for RAG
|
||||
mongodb_index = MongoDBIndex(vector_store, collection, self.config)
|
||||
await mongodb_index.initialize()
|
||||
|
||||
# Create vector store with index wrapper
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=mongodb_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
# Cache the vector store
|
||||
self.cache[vector_store.identifier] = vector_store_with_index
|
||||
|
||||
# Save vector store info to KVStore for persistence
|
||||
if self.kvstore:
|
||||
await self.kvstore.set(
|
||||
f"{VECTOR_DBS_PREFIX}{vector_store.identifier}",
|
||||
vector_store.model_dump_json(),
|
||||
)
|
||||
|
||||
logger.info(f"Registered vector store for RAG: {vector_store.identifier}")
|
||||
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
"""Unregister a vector store."""
|
||||
if vector_store_id in self.cache:
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
||||
# Clean up from KV store
|
||||
if self.kvstore:
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
||||
|
||||
logger.info(f"Unregistered vector store: {vector_store_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""Insert chunks into the vector database optimized for RAG."""
|
||||
vector_db_with_index = await self._get_vector_db_index(vector_store_id)
|
||||
await vector_db_with_index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
"""Query chunks from the vector database optimized for RAG context retrieval."""
|
||||
vector_db_with_index = await self._get_vector_db_index(vector_store_id)
|
||||
return await vector_db_with_index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from the vector database."""
|
||||
vector_db_with_index = await self._get_vector_db_index(store_id)
|
||||
await vector_db_with_index.index.delete_chunks(chunks_for_deletion)
|
||||
|
||||
async def _get_vector_db_index(self, vector_db_id: str) -> VectorStoreWithIndex:
|
||||
"""Get vector store index from cache."""
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
async def _load_existing_vector_dbs(self) -> None:
|
||||
"""Load existing vector databases from KVStore."""
|
||||
if not self.kvstore:
|
||||
return
|
||||
|
||||
try:
|
||||
# Use keys_in_range to get all vector database keys from KVStore
|
||||
# This searches for keys with the prefix by using range scan
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
# Create an end key by incrementing the last character
|
||||
end_key = VECTOR_DBS_PREFIX[:-1] + chr(ord(VECTOR_DBS_PREFIX[-1]) + 1)
|
||||
|
||||
vector_db_keys = await self.kvstore.keys_in_range(start_key, end_key)
|
||||
|
||||
for key in vector_db_keys:
|
||||
try:
|
||||
vector_store_data = await self.kvstore.get(key)
|
||||
if vector_store_data:
|
||||
import json
|
||||
|
||||
vector_store = VectorStore(**json.loads(vector_store_data))
|
||||
# Register the vector store without re-initializing
|
||||
await self._register_existing_vector_store(vector_store)
|
||||
logger.info(f"Loaded existing RAG-optimized vector store: {vector_store.identifier}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load vector store from key {key}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load existing vector stores: {e}")
|
||||
|
||||
async def _register_existing_vector_store(self, vector_store: VectorStore) -> None:
|
||||
"""Register an existing vector store without re-initialization."""
|
||||
if self.database is None:
|
||||
raise RuntimeError("MongoDB database not initialized")
|
||||
|
||||
# Create collection name from vector store identifier
|
||||
collection_name = sanitize_collection_name(vector_store.identifier)
|
||||
collection = self.database[collection_name]
|
||||
|
||||
# Create MongoDB index without initialization (collection already exists)
|
||||
mongodb_index = MongoDBIndex(vector_store, collection, self.config)
|
||||
|
||||
# Create vector store with index wrapper
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=mongodb_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
# Cache the vector store
|
||||
self.cache[vector_store.identifier] = vector_store_with_index
|
||||
5
tests/unit/providers/vector_io/__init__.py
Normal file
5
tests/unit/providers/vector_io/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
Loading…
Add table
Add a link
Reference in a new issue