mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
init: add mongodb in vector_io
This commit is contained in:
parent
0066d986c5
commit
dd1136a1ac
5 changed files with 1080 additions and 0 deletions
19
llama_stack/providers/remote/vector_io/mongodb/__init__.py
Normal file
19
llama_stack/providers/remote/vector_io/mongodb/__init__.py
Normal file
|
|
@@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MongoDBVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: dict[Api, ProviderSpec]):
    """Build and initialize the MongoDB vector-IO adapter.

    The adapter module is imported lazily so that pymongo is only required
    when this provider is actually instantiated. The files API is optional:
    it is forwarded when present in ``deps`` and left as ``None`` otherwise.
    """
    from .mongodb import MongoDBVectorIOAdapter

    adapter = MongoDBVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
    await adapter.initialize()
    return adapter
66
llama_stack/providers/remote/vector_io/mongodb/config.py
Normal file
66
llama_stack/providers/remote/vector_io/mongodb/config.py
Normal file
|
|
@@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
class MongoDBVectorIOConfig(BaseModel):
    """Configuration for MongoDB Atlas Vector Search provider.

    This provider connects to MongoDB Atlas and uses Vector Search for RAG operations.
    """

    # --- Atlas connection ---
    connection_string: str = Field(
        description="MongoDB Atlas connection string (e.g., mongodb+srv://user:pass@cluster.mongodb.net/)"
    )
    database_name: str = Field(default="llama_stack", description="Database name to use for vector collections")

    # --- Vector search index ---
    index_name: str = Field(default="vector_index", description="Name of the vector search index")
    path_field: str = Field(default="embedding", description="Field name for storing embeddings")
    similarity_metric: str = Field(
        default="cosine",
        description="Similarity metric: cosine, euclidean, or dotProduct",
    )

    # --- Client connection options ---
    max_pool_size: int = Field(default=100, description="Maximum connection pool size")
    timeout_ms: int = Field(default=30000, description="Connection timeout in milliseconds")

    # --- Metadata persistence ---
    kvstore: KVStoreConfig = Field(description="Config for KV store backend for metadata storage")

    @classmethod
    def sample_run_config(
        cls,
        __distro_dir__: str,
        connection_string: str = "${env.MONGODB_CONNECTION_STRING}",
        database_name: str = "llama_stack",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Return a sample run configuration dict suitable for distro templates."""
        # Metadata registry defaults to a SQLite KV store under the distro dir.
        kvstore_config = SqliteKVStoreConfig.sample_run_config(
            __distro_dir__=__distro_dir__,
            db_name="mongodb_registry.db",
        )
        return {
            "connection_string": connection_string,
            "database_name": database_name,
            "index_name": "vector_index",
            "path_field": "embedding",
            "similarity_metric": "cosine",
            "max_pool_size": 100,
            "timeout_ms": 30000,
            "kvstore": kvstore_config,
        }
|
||||
601
llama_stack/providers/remote/vector_io/mongodb/mongodb.py
Normal file
601
llama_stack/providers/remote/vector_io/mongodb/mongodb.py
Normal file
|
|
@@ -0,0 +1,601 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
import heapq
import json
import time
from typing import Any

from numpy.typing import NDArray
from pymongo import MongoClient, ReplaceOne
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.operations import SearchIndexModel
from pymongo.server_api import ServerApi

from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    VectorDBsProtocolPrivate,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
    OpenAIVectorStoreMixin,
)
from llama_stack.providers.utils.memory.vector_store import (
    ChunkForDeletion,
    EmbeddingIndex,
    VectorDBWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import (
    WeightedInMemoryAggregator,
    sanitize_collection_name,
)

from .config import MongoDBVectorIOConfig
|
||||
|
||||
# Module-level logger scoped to this provider.
logger = get_logger(name=__name__, category="vector_io::mongodb")

# Schema version for keys persisted in the KV store; bump on layout changes.
VERSION = "v1"
# Key prefixes under which this provider persists its metadata records.
VECTOR_DBS_PREFIX = f"vector_dbs:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:mongodb:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:mongodb:{VERSION}::"
||||
class MongoDBIndex(EmbeddingIndex):
    """MongoDB Atlas Vector Search index implementation optimized for RAG.

    One instance wraps a single MongoDB collection. Each stored document
    carries the chunk text ("text"/"content"), metadata, the embedding
    vector under ``config.path_field``, and the fully serialized chunk
    under "document". Retrieval supports vector ($vectorSearch), keyword
    ($text), and hybrid (client-side reranked) search.
    """

    def __init__(
        self,
        vector_db: VectorDB,
        collection: Collection,
        config: MongoDBVectorIOConfig,
    ):
        self.vector_db = vector_db
        self.collection = collection
        self.config = config
        # Embedding dimensionality; required when defining the Atlas vector index.
        self.dimension = vector_db.embedding_dimension

    async def initialize(self) -> None:
        """Initialize the MongoDB collection and ensure vector search index exists.

        Best-effort: failures are logged rather than raised so the provider can
        still run when indexes must be created manually in the Atlas dashboard.
        """
        try:
            # MongoDB creates collections lazily; force creation up front by
            # inserting (then removing) a dummy document.
            collection_names = self.collection.database.list_collection_names()
            if self.collection.name not in collection_names:
                logger.info(f"Creating collection '{self.collection.name}'")
                self.collection.insert_one({"_id": "__dummy__", "dummy": True})
                self.collection.delete_one({"_id": "__dummy__"})
                logger.info(f"Collection '{self.collection.name}' created successfully")

            # Create optimized vector search index for RAG.
            await self._create_vector_search_index()

            # Create text index for hybrid search.
            await self._ensure_text_index()

        except Exception as e:
            logger.exception(
                f"Failed to initialize MongoDB index for vector_db: {self.vector_db.identifier}. "
                f"Collection name: {self.collection.name}. Error: {str(e)}"
            )
            # Don't fail completely - just log the error and continue.
            logger.warning(
                "Continuing without complete index initialization. "
                "You may need to create indexes manually in MongoDB Atlas dashboard."
            )

    async def _create_vector_search_index(self) -> None:
        """Create the Atlas vector search index if it does not already exist."""
        try:
            indexes = list(self.collection.list_search_indexes())
            index_exists = any(idx.get("name") == self.config.index_name for idx in indexes)
            if index_exists:
                return

            # New "vectorSearch" index format, per MongoDB's RAG examples.
            search_index_model = SearchIndexModel(
                definition={
                    "fields": [
                        {
                            "type": "vector",
                            "numDimensions": self.dimension,
                            "path": self.config.path_field,
                            "similarity": self._convert_similarity_metric(self.config.similarity_metric),
                        }
                    ]
                },
                name=self.config.index_name,
                type="vectorSearch",
            )

            logger.info(
                f"Creating vector search index '{self.config.index_name}' for RAG on collection '{self.collection.name}'"
            )
            self.collection.create_search_index(model=search_index_model)

            # Index builds are asynchronous server-side; block until queryable.
            await self._wait_for_index_ready()

            logger.info("Vector search index created and ready for RAG queries")

        except Exception as e:
            logger.warning(f"Failed to create vector search index: {e}")

    def _convert_similarity_metric(self, metric: str) -> str:
        """Convert internal similarity metric to MongoDB Atlas format."""
        metric_map = {
            "cosine": "cosine",
            "euclidean": "euclidean",
            "dotProduct": "dotProduct",
            "dot_product": "dotProduct",
        }
        # Unknown metrics fall back to cosine rather than failing index creation.
        return metric_map.get(metric, "cosine")

    async def _wait_for_index_ready(self) -> None:
        """Poll until the vector search index is queryable (up to 5 minutes).

        Fix: polls with ``asyncio.sleep`` instead of ``time.sleep`` so the
        event loop is not blocked while the server builds the index.
        """
        logger.info("Waiting for vector search index to be ready...")

        max_wait_time = 300  # 5 minutes max wait
        wait_interval = 5
        elapsed_time = 0

        while elapsed_time < max_wait_time:
            try:
                indices = list(self.collection.list_search_indexes(self.config.index_name))
                if len(indices) and indices[0].get("queryable") is True:
                    logger.info(f"Vector search index '{self.config.index_name}' is ready for querying")
                    return
            except Exception:
                # list_search_indexes can fail transiently while the build runs.
                pass

            await asyncio.sleep(wait_interval)
            elapsed_time += wait_interval

        logger.warning(f"Vector search index may not be fully ready after {max_wait_time}s")

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray) -> None:
        """Upsert chunks with their embeddings into the collection.

        Raises:
            ValueError: if chunks and embeddings differ in length.
        """
        if len(chunks) != len(embeddings):
            raise ValueError(f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}")

        operations = []
        for chunk, embedding in zip(chunks, embeddings, strict=True):
            # Structure document for optimal RAG retrieval.
            doc = {
                "_id": chunk.chunk_id,
                "chunk_id": chunk.chunk_id,
                "text": interleaved_content_as_str(chunk.content),  # Key field for RAG context
                "content": interleaved_content_as_str(chunk.content),  # Backward compatibility
                "metadata": chunk.metadata or {},
                "chunk_metadata": (chunk.chunk_metadata.model_dump() if chunk.chunk_metadata else {}),
                self.config.path_field: embedding.tolist(),  # Vector embeddings
                "document": chunk.model_dump(),  # Full chunk data
            }
            operations.append(ReplaceOne({"_id": doc["_id"]}, doc, upsert=True))

        try:
            # Fix: one unordered bulk round-trip instead of a replace_one per
            # chunk (same upsert semantics, O(1) network calls per batch).
            if operations:
                self.collection.bulk_write(operations, ordered=False)

            logger.debug(f"Successfully added {len(chunks)} chunks optimized for RAG to MongoDB collection")
        except Exception as e:
            logger.exception(f"Failed to add chunks to MongoDB collection: {e}")
            raise

    async def query_vector(
        self,
        embedding: NDArray,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        """Perform vector similarity search via the Atlas $vectorSearch stage.

        Raises:
            RuntimeError: if the aggregation fails.
        """
        try:
            pipeline = [
                {
                    "$vectorSearch": {
                        "index": self.config.index_name,
                        "queryVector": embedding.tolist(),
                        "path": self.config.path_field,
                        "numCandidates": k * 10,  # Get more candidates for better results
                        "limit": k,
                    }
                },
                {
                    "$project": {
                        "_id": 0,
                        "text": 1,  # Primary field for RAG context
                        "content": 1,  # Backward compatibility
                        "metadata": 1,
                        "chunk_metadata": 1,
                        "document": 1,
                        "score": {"$meta": "vectorSearchScore"},
                    }
                },
                {"$match": {"score": {"$gte": score_threshold}}},
            ]

            results = list(self.collection.aggregate(pipeline))

            chunks = []
            scores = []
            for result in results:
                score = result.get("score", 0.0)
                # Defensive re-check of the threshold applied server-side above.
                if score >= score_threshold:
                    chunk_data = result.get("document", {})
                    if chunk_data:
                        chunks.append(Chunk(**chunk_data))
                        scores.append(float(score))

            logger.debug(f"Vector search for RAG returned {len(chunks)} results")
            return QueryChunksResponse(chunks=chunks, scores=scores)

        except Exception as e:
            logger.exception(f"Vector search for RAG failed: {e}")
            raise RuntimeError(f"Vector search for RAG failed: {e}") from e

    async def query_keyword(
        self,
        query_string: str,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        """Perform keyword search using MongoDB's $text operator.

        Raises:
            RuntimeError: if the aggregation fails.
        """
        try:
            # Ensure text index exists before issuing a $text query.
            await self._ensure_text_index()

            pipeline = [
                {"$match": {"$text": {"$search": query_string}}},
                {
                    "$project": {
                        "_id": 0,
                        "text": 1,  # Primary field for RAG context
                        "content": 1,  # Backward compatibility
                        "metadata": 1,
                        "chunk_metadata": 1,
                        "document": 1,
                        "score": {"$meta": "textScore"},
                    }
                },
                {"$match": {"score": {"$gte": score_threshold}}},
                {"$sort": {"score": {"$meta": "textScore"}}},
                {"$limit": k},
            ]

            results = list(self.collection.aggregate(pipeline))

            chunks = []
            scores = []
            for result in results:
                score = result.get("score", 0.0)
                if score >= score_threshold:
                    chunk_data = result.get("document", {})
                    if chunk_data:
                        chunks.append(Chunk(**chunk_data))
                        scores.append(float(score))

            logger.debug(f"Keyword search for RAG returned {len(chunks)} results")
            return QueryChunksResponse(chunks=chunks, scores=scores)

        except Exception as e:
            logger.exception(f"Keyword search for RAG failed: {e}")
            raise RuntimeError(f"Keyword search for RAG failed: {e}") from e

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Combine vector and keyword search results with client-side reranking.

        Both underlying searches run with a zero threshold; the configured
        ``score_threshold`` is applied to the combined (reranked) scores.
        """
        if reranker_params is None:
            reranker_params = {}

        # Get results from both search methods (unthresholded).
        vector_response = await self.query_vector(embedding, k, 0.0)
        keyword_response = await self.query_keyword(query_string, k, 0.0)

        # Convert responses to chunk_id -> score dictionaries.
        vector_scores = {
            chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
        }
        keyword_scores = {
            chunk.chunk_id: score
            for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
        }

        # Combine scores using the shared reranking utility.
        combined_scores = WeightedInMemoryAggregator.combine_search_results(
            vector_scores, keyword_scores, reranker_type, reranker_params
        )

        # Top-k by combined score, then apply the caller's threshold.
        top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
        filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]

        # Map ids back to chunk objects (either source may have supplied them).
        chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}

        chunks = []
        scores = []
        for doc_id, score in filtered_items:
            if doc_id in chunk_map:
                chunks.append(chunk_map[doc_id])
                scores.append(score)

        logger.debug(f"Hybrid search for RAG returned {len(chunks)} results")
        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Delete the given chunks from the MongoDB collection by id."""
        chunk_ids = [c.chunk_id for c in chunks_for_deletion]
        try:
            result = self.collection.delete_many({"_id": {"$in": chunk_ids}})
            logger.debug(f"Deleted {result.deleted_count} chunks from MongoDB collection")
        except Exception as e:
            logger.exception(f"Failed to delete chunks: {e}")
            raise

    async def delete(self) -> None:
        """Drop the entire backing collection (and its indexes)."""
        try:
            self.collection.drop()
            logger.debug(f"Dropped MongoDB collection: {self.collection.name}")
        except Exception as e:
            logger.exception(f"Failed to drop collection: {e}")
            raise

    async def _ensure_text_index(self) -> None:
        """Ensure a text index exists on the content fields (best-effort)."""
        try:
            indexes = list(self.collection.list_indexes())
            text_index_exists = any(
                any(key.startswith(("content", "text")) for key in idx.get("key", {}).keys())
                and idx.get("textIndexVersion") is not None
                for idx in indexes
            )

            if not text_index_exists:
                logger.info("Creating text search index on content fields for RAG")
                # Index both 'text' and 'content' fields for comprehensive text search.
                self.collection.create_index([("text", "text"), ("content", "text")])
                logger.info("Text search index created successfully for RAG")

        except Exception as e:
            logger.warning(f"Failed to create text index for RAG: {e}")
|
||||
class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
    """MongoDB Atlas Vector Search adapter for Llama Stack optimized for RAG workflows.

    Owns the MongoDB client/database, a KV store for vector-DB metadata, and
    an in-memory cache of registered vector DBs keyed by identifier.
    """

    def __init__(
        self,
        config: MongoDBVectorIOConfig,
        inference_api,
        files_api=None,
    ) -> None:
        # files_api is optional — None when the files API is not deployed.
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
        self.client: MongoClient | None = None
        self.database: Database | None = None
        # In-memory registry of vector DBs, keyed by identifier.
        self.cache: dict[str, VectorDBWithIndex] = {}
        self.kvstore: KVStore | None = None

    async def initialize(self) -> None:
        """Initialize MongoDB connection optimized for RAG workflows.

        Raises:
            RuntimeError: if the KV store or MongoDB connection cannot be set up.
        """
        logger.info("Initializing MongoDB Atlas Vector IO adapter for RAG")

        try:
            # Initialize KV store for metadata.
            self.kvstore = await kvstore_impl(self.config.kvstore)

            # Connect to MongoDB with optimized settings for RAG.
            self.client = MongoClient(
                self.config.connection_string,
                server_api=ServerApi("1"),
                maxPoolSize=self.config.max_pool_size,
                serverSelectionTimeoutMS=self.config.timeout_ms,
                # Additional settings for RAG performance.
                retryWrites=True,
                readPreference="primaryPreferred",
            )

            # Fail fast if the cluster is unreachable.
            self.client.admin.command("ping")
            logger.info("Successfully connected to MongoDB Atlas for RAG")

            self.database = self.client[self.config.database_name]

            # Initialize OpenAI vector stores.
            await self.initialize_openai_vector_stores()

            # Load existing vector databases.
            await self._load_existing_vector_dbs()

            logger.info("MongoDB Atlas Vector IO adapter for RAG initialized successfully")

        except Exception as e:
            logger.exception("Failed to initialize MongoDB Atlas Vector IO adapter for RAG")
            raise RuntimeError("Failed to initialize MongoDB Atlas Vector IO adapter for RAG") from e

    async def shutdown(self) -> None:
        """Close the MongoDB client connection, if one was opened."""
        if self.client:
            self.client.close()
            logger.info("MongoDB Atlas RAG connection closed")

    async def health(self) -> HealthResponse:
        """Report adapter health by pinging the MongoDB cluster."""
        try:
            if self.client:
                self.client.admin.command("ping")
                return HealthResponse(status=HealthStatus.OK)
            else:
                return HealthResponse(status=HealthStatus.ERROR, message="MongoDB client not initialized")
        except Exception as e:
            return HealthResponse(
                status=HealthStatus.ERROR,
                message=f"MongoDB RAG health check failed: {str(e)}",
            )

    async def register_vector_db(self, vector_db: VectorDB) -> None:
        """Register a new vector database: create/initialize its collection,
        cache it, and persist its metadata.

        Raises:
            RuntimeError: if the adapter was not initialized.
        """
        if self.database is None:
            raise RuntimeError("MongoDB database not initialized")

        # One MongoDB collection per vector DB, named from its identifier.
        collection_name = sanitize_collection_name(vector_db.identifier)
        collection = self.database[collection_name]

        # Create and initialize the MongoDB index (collection + search indexes).
        mongodb_index = MongoDBIndex(vector_db, collection, self.config)
        await mongodb_index.initialize()

        vector_db_with_index = VectorDBWithIndex(
            vector_db=vector_db,
            index=mongodb_index,
            inference_api=self.inference_api,
        )

        self.cache[vector_db.identifier] = vector_db_with_index

        # Persist metadata so the DB can be reloaded after a restart.
        if self.kvstore:
            await self.kvstore.set(
                f"{VECTOR_DBS_PREFIX}{vector_db.identifier}",
                vector_db.model_dump_json(),
            )

        logger.info(f"Registered vector database for RAG: {vector_db.identifier}")

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Unregister a vector database: drop its collection, evict it from
        the cache, and remove its persisted metadata."""
        if vector_db_id in self.cache:
            await self.cache[vector_db_id].index.delete()
            del self.cache[vector_db_id]

        if self.kvstore:
            await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")

        logger.info(f"Unregistered vector database: {vector_db_id}")

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        """Insert chunks into the named vector database.

        Note: ``ttl_seconds`` is accepted for interface compatibility but is
        not applied by this backend.
        """
        vector_db_with_index = await self._get_vector_db_index(vector_db_id)
        await vector_db_with_index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Query chunks from the named vector database for RAG context retrieval."""
        vector_db_with_index = await self._get_vector_db_index(vector_db_id)
        return await vector_db_with_index.query_chunks(query, params)

    async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Delete the given chunks from the named vector database."""
        vector_db_with_index = await self._get_vector_db_index(store_id)
        await vector_db_with_index.index.delete_chunks(chunks_for_deletion)

    async def _get_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
        """Look up a vector database in the cache.

        Raises:
            VectorStoreNotFoundError: if the id is not registered.
        """
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        raise VectorStoreNotFoundError(vector_db_id)

    async def _load_existing_vector_dbs(self) -> None:
        """Reload vector databases persisted in the KV store into the cache.

        Best-effort: a single corrupt record is skipped with a warning.
        """
        if not self.kvstore:
            return

        try:
            # Range-scan for all keys under the vector DB prefix: the end key
            # is the prefix with its last character incremented by one.
            start_key = VECTOR_DBS_PREFIX
            end_key = VECTOR_DBS_PREFIX[:-1] + chr(ord(VECTOR_DBS_PREFIX[-1]) + 1)

            vector_db_keys = await self.kvstore.keys_in_range(start_key, end_key)

            for key in vector_db_keys:
                try:
                    vector_db_data = await self.kvstore.get(key)
                    if vector_db_data:
                        # Fix: json is imported at module level instead of
                        # per-iteration inside this loop.
                        vector_db = VectorDB(**json.loads(vector_db_data))
                        # Register without re-running index creation.
                        await self._register_existing_vector_db(vector_db)
                        logger.info(f"Loaded existing RAG-optimized vector database: {vector_db.identifier}")
                except Exception as e:
                    logger.warning(f"Failed to load vector database from key {key}: {e}")
                    continue

        except Exception as e:
            logger.warning(f"Failed to load existing vector databases: {e}")

    async def _register_existing_vector_db(self, vector_db: VectorDB) -> None:
        """Register an existing vector database without re-initializing its
        collection or search indexes (they already exist server-side).

        Raises:
            RuntimeError: if the adapter was not initialized.
        """
        if self.database is None:
            raise RuntimeError("MongoDB database not initialized")

        collection_name = sanitize_collection_name(vector_db.identifier)
        collection = self.database[collection_name]

        # Create MongoDB index without initialization (collection already exists).
        mongodb_index = MongoDBIndex(vector_db, collection, self.config)

        vector_db_with_index = VectorDBWithIndex(
            vector_db=vector_db,
            index=mongodb_index,
            inference_api=self.inference_api,
        )

        self.cache[vector_db.identifier] = vector_db_with_index
Loading…
Add table
Add a link
Reference in a new issue