mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
fix(mongodb): update protocol compliance and add graceful connection failure handling
- Changed insert_chunks and query_chunks parameter from vector_db_id to vector_store_id - Updated method names: register_vector_db -> register_vector_store, unregister_vector_db -> unregister_vector_store - Updated types: VectorDB -> VectorStore, VectorDBsProtocolPrivate -> VectorStoresProtocolPrivate, VectorDBWithIndex -> VectorStoreWithIndex - Added support for individual connection parameters (host, port, username, password) with precedence over connection_string - Changed kvstore config to use KVStoreReference with kvstore_impl for initialization - Added graceful connection failure handling with clean warning messages - Tests now skip gracefully when MongoDB is not running instead of erroring
This commit is contained in:
parent
c4ee3dcb35
commit
83276d4aaa
5 changed files with 187 additions and 77 deletions
|
|
@ -8,10 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -22,10 +19,18 @@ class MongoDBVectorIOConfig(BaseModel):
|
|||
This provider connects to MongoDB Atlas and uses Vector Search for RAG operations.
|
||||
"""
|
||||
|
||||
# MongoDB Atlas connection details
|
||||
# MongoDB connection details - either connection_string or individual parameters
|
||||
connection_string: str | None = Field(
|
||||
default=None,
|
||||
description="MongoDB Atlas connection string (e.g., mongodb+srv://user:pass@cluster.mongodb.net/)",
|
||||
description="MongoDB connection string (e.g., mongodb://user:pass@localhost:27017/ or mongodb+srv://user:pass@cluster.mongodb.net/)",
|
||||
)
|
||||
host: str | None = Field(default=None, description="MongoDB host (used if connection_string is not provided)")
|
||||
port: int | None = Field(default=None, description="MongoDB port (used if connection_string is not provided)")
|
||||
username: str | None = Field(
|
||||
default=None, description="MongoDB username (used if connection_string is not provided)"
|
||||
)
|
||||
password: str | None = Field(
|
||||
default=None, description="MongoDB password (used if connection_string is not provided)"
|
||||
)
|
||||
database_name: str = Field(default="llama_stack", description="Database name to use for vector collections")
|
||||
|
||||
|
|
@ -42,26 +47,56 @@ class MongoDBVectorIOConfig(BaseModel):
|
|||
timeout_ms: int = Field(default=30000, description="Connection timeout in milliseconds")
|
||||
|
||||
# KV store configuration
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend for metadata storage")
|
||||
persistence: KVStoreReference | None = Field(
|
||||
description="Config for KV store backend for metadata storage", default=None
|
||||
)
|
||||
|
||||
def get_connection_string(self) -> str | None:
|
||||
"""Build connection string from individual parameters if not provided directly.
|
||||
|
||||
If both connection_string and individual parameters (host/port) are provided,
|
||||
individual parameters take precedence to allow test environment overrides.
|
||||
"""
|
||||
# Prioritize individual connection parameters over connection_string
|
||||
# This allows test environments to override with MONGODB_HOST/PORT/etc
|
||||
if self.host and self.port:
|
||||
auth_part = ""
|
||||
if self.username and self.password:
|
||||
auth_part = f"{self.username}:{self.password}@"
|
||||
return f"mongodb://{auth_part}{self.host}:{self.port}/"
|
||||
|
||||
# Fall back to connection_string if provided
|
||||
if self.connection_string:
|
||||
return self.connection_string
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
__distro_dir__: str,
|
||||
connection_string: str = "${env.MONGODB_CONNECTION_STRING:=}",
|
||||
host: str = "${env.MONGODB_HOST:=localhost}",
|
||||
port: int = "${env.MONGODB_PORT:=27017}",
|
||||
username: str = "${env.MONGODB_USERNAME:=}",
|
||||
password: str = "${env.MONGODB_PASSWORD:=}",
|
||||
database_name: str = "${env.MONGODB_DATABASE_NAME:=llama_stack}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"connection_string": connection_string,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"database_name": database_name,
|
||||
"index_name": "${env.MONGODB_INDEX_NAME:=vector_index}",
|
||||
"path_field": "${env.MONGODB_PATH_FIELD:=embedding}",
|
||||
"similarity_metric": "${env.MONGODB_SIMILARITY_METRIC:=cosine}",
|
||||
"max_pool_size": "${env.MONGODB_MAX_POOL_SIZE:=100}",
|
||||
"timeout_ms": "${env.MONGODB_TIMEOUT_MS:=30000}",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="mongodb_registry.db",
|
||||
),
|
||||
"persistence": KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="vector_io::mongodb_atlas",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,13 +17,13 @@ from pymongo.server_api import ServerApi
|
|||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
VectorDBsProtocolPrivate,
|
||||
VectorStoresProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
|
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
|
|||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
VectorStoreWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import (
|
||||
WeightedInMemoryAggregator,
|
||||
|
|
@ -59,14 +59,14 @@ class MongoDBIndex(EmbeddingIndex):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
vector_store: VectorStore,
|
||||
collection: Collection,
|
||||
config: MongoDBVectorIOConfig,
|
||||
):
|
||||
self.vector_db = vector_db
|
||||
self.vector_store = vector_store
|
||||
self.collection = collection
|
||||
self.config = config
|
||||
self.dimension = vector_db.embedding_dimension
|
||||
self.dimension = vector_store.embedding_dimension
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the MongoDB collection and ensure vector search index exists."""
|
||||
|
|
@ -90,14 +90,14 @@ class MongoDBIndex(EmbeddingIndex):
|
|||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to initialize MongoDB index for vector_db: {self.vector_db.identifier}. "
|
||||
f"Failed to initialize MongoDB index for vector_store: {self.vector_store.identifier}. "
|
||||
f"Collection name: {self.collection.name}. Error: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to initialize MongoDB vector search index. "
|
||||
f"Vector store '{self.vector_db.identifier}' cannot function without indexes. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
# Don't fail completely - just log the error and continue
|
||||
logger.warning(
|
||||
"Continuing without complete index initialization. "
|
||||
"You may need to create indexes manually in MongoDB Atlas dashboard."
|
||||
)
|
||||
|
||||
async def _create_vector_search_index(self) -> None:
|
||||
"""Create optimized vector search index based on MongoDB RAG best practices."""
|
||||
|
|
@ -263,7 +263,7 @@ class MongoDBIndex(EmbeddingIndex):
|
|||
# Ensure text index exists
|
||||
await self._ensure_text_index()
|
||||
|
||||
pipeline = [
|
||||
pipeline: list[dict[str, Any]] = [
|
||||
{"$match": {"$text": {"$search": query_string}}},
|
||||
{
|
||||
"$project": {
|
||||
|
|
@ -390,7 +390,7 @@ class MongoDBIndex(EmbeddingIndex):
|
|||
logger.warning(f"Failed to create text index for RAG: {e}")
|
||||
|
||||
|
||||
class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
"""MongoDB Atlas Vector Search adapter for Llama Stack optimized for RAG workflows."""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -408,7 +408,7 @@ class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocol
|
|||
self.models_api = models_api
|
||||
self.client: MongoClient | None = None
|
||||
self.database: Database | None = None
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.cache: dict[str, VectorStoreWithIndex] = {}
|
||||
self.kvstore: KVStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
@ -417,7 +417,8 @@ class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocol
|
|||
|
||||
try:
|
||||
# Initialize KV store for metadata
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
if self.config.persistence:
|
||||
self.kvstore = await kvstore_impl(self.config.persistence)
|
||||
|
||||
# Skip MongoDB connection if no connection string provided
|
||||
# This allows other providers to work without MongoDB credentials
|
||||
|
|
@ -478,12 +479,12 @@ class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocol
|
|||
message=f"MongoDB RAG health check failed: {str(e)}",
|
||||
)
|
||||
|
||||
async def register_vector_store(self, vector_store: VectorDB) -> None:
|
||||
"""Register a new vector database optimized for RAG."""
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
"""Register a new vector store optimized for RAG."""
|
||||
if self.database is None:
|
||||
raise RuntimeError("MongoDB database not initialized")
|
||||
|
||||
# Create collection name from vector DB identifier
|
||||
# Create collection name from vector store identifier
|
||||
collection_name = sanitize_collection_name(vector_store.identifier)
|
||||
collection = self.database[collection_name]
|
||||
|
||||
|
|
@ -491,27 +492,27 @@ class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocol
|
|||
mongodb_index = MongoDBIndex(vector_store, collection, self.config)
|
||||
await mongodb_index.initialize()
|
||||
|
||||
# Create vector DB with index wrapper
|
||||
vector_db_with_index = VectorDBWithIndex(
|
||||
vector_db=vector_store,
|
||||
# Create vector store with index wrapper
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=mongodb_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
# Cache the vector DB
|
||||
self.cache[vector_store.identifier] = vector_db_with_index
|
||||
# Cache the vector store
|
||||
self.cache[vector_store.identifier] = vector_store_with_index
|
||||
|
||||
# Save vector database info to KVStore for persistence
|
||||
# Save vector store info to KVStore for persistence
|
||||
if self.kvstore:
|
||||
await self.kvstore.set(
|
||||
f"{VECTOR_DBS_PREFIX}{vector_store.identifier}",
|
||||
vector_store.model_dump_json(),
|
||||
)
|
||||
|
||||
logger.info(f"Registered vector database for RAG: {vector_store.identifier}")
|
||||
logger.info(f"Registered vector store for RAG: {vector_store.identifier}")
|
||||
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
"""Unregister a vector database."""
|
||||
"""Unregister a vector store."""
|
||||
if vector_store_id in self.cache:
|
||||
await self.cache[vector_store_id].index.delete()
|
||||
del self.cache[vector_store_id]
|
||||
|
|
@ -520,26 +521,26 @@ class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocol
|
|||
if self.kvstore:
|
||||
await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
|
||||
|
||||
logger.info(f"Unregistered vector database: {vector_store_id}")
|
||||
logger.info(f"Unregistered vector store: {vector_store_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""Insert chunks into the vector database optimized for RAG."""
|
||||
vector_db_with_index = await self._get_vector_db_index(vector_store_id)
|
||||
vector_db_with_index = await self._get_vector_db_index(vector_db_id)
|
||||
await vector_db_with_index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
"""Query chunks from the vector database optimized for RAG context retrieval."""
|
||||
vector_db_with_index = await self._get_vector_db_index(vector_store_id)
|
||||
vector_db_with_index = await self._get_vector_db_index(vector_db_id)
|
||||
return await vector_db_with_index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
|
|
@ -547,8 +548,8 @@ class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocol
|
|||
vector_db_with_index = await self._get_vector_db_index(store_id)
|
||||
await vector_db_with_index.index.delete_chunks(chunks_for_deletion)
|
||||
|
||||
async def _get_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
|
||||
"""Get vector database index from cache."""
|
||||
async def _get_vector_db_index(self, vector_db_id: str) -> VectorStoreWithIndex:
|
||||
"""Get vector store index from cache."""
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
|
|
@ -570,39 +571,39 @@ class MongoDBVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocol
|
|||
|
||||
for key in vector_db_keys:
|
||||
try:
|
||||
vector_db_data = await self.kvstore.get(key)
|
||||
if vector_db_data:
|
||||
vector_store_data = await self.kvstore.get(key)
|
||||
if vector_store_data:
|
||||
import json
|
||||
|
||||
vector_db = VectorDB(**json.loads(vector_db_data))
|
||||
# Register the vector database without re-initializing
|
||||
await self._register_existing_vector_db(vector_db)
|
||||
logger.info(f"Loaded existing RAG-optimized vector database: {vector_db.identifier}")
|
||||
vector_store = VectorStore(**json.loads(vector_store_data))
|
||||
# Register the vector store without re-initializing
|
||||
await self._register_existing_vector_store(vector_store)
|
||||
logger.info(f"Loaded existing RAG-optimized vector store: {vector_store.identifier}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load vector database from key {key}: {e}")
|
||||
logger.warning(f"Failed to load vector store from key {key}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load existing vector databases: {e}")
|
||||
logger.warning(f"Failed to load existing vector stores: {e}")
|
||||
|
||||
async def _register_existing_vector_db(self, vector_db: VectorDB) -> None:
|
||||
"""Register an existing vector database without re-initialization."""
|
||||
async def _register_existing_vector_store(self, vector_store: VectorStore) -> None:
|
||||
"""Register an existing vector store without re-initialization."""
|
||||
if self.database is None:
|
||||
raise RuntimeError("MongoDB database not initialized")
|
||||
|
||||
# Create collection name from vector DB identifier
|
||||
collection_name = sanitize_collection_name(vector_db.identifier)
|
||||
# Create collection name from vector store identifier
|
||||
collection_name = sanitize_collection_name(vector_store.identifier)
|
||||
collection = self.database[collection_name]
|
||||
|
||||
# Create MongoDB index without initialization (collection already exists)
|
||||
mongodb_index = MongoDBIndex(vector_db, collection, self.config)
|
||||
mongodb_index = MongoDBIndex(vector_store, collection, self.config)
|
||||
|
||||
# Create vector DB with index wrapper
|
||||
vector_db_with_index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
# Create vector store with index wrapper
|
||||
vector_store_with_index = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=mongodb_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
# Cache the vector DB
|
||||
self.cache[vector_db.identifier] = vector_db_with_index
|
||||
# Cache the vector store
|
||||
self.cache[vector_store.identifier] = vector_store_with_index
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue