Mirror of https://github.com/meta-llama/llama-stack.git (synced 2026-01-02 13:14:32 +00:00)
adding mongodb vector_io module
updated mongodb.py from print to log
add documentation for mongodb vector search module
changed insert to update in mongodb
bug fix: mongodb json object conversion error
This commit is contained in:
parent d224ae0c8e
commit 80d9d50954
8 changed files with 503 additions and 65 deletions

@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
+from .config import MongoDBVectorIOConfig
+
+
+async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: Dict[Api, ProviderSpec]):
+    from .mongodb import MongoDBVectorIOAdapter
+
+    impl = MongoDBVectorIOAdapter(config, deps[Api.inference])
+    await impl.initialize()
+    return impl
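
The entry point above follows llama-stack's provider-adapter convention: the stack resolves each dependency by Api and hands the mapping to get_adapter_impl. A minimal sketch of that call, not part of the diff; inference_impl is hypothetical, standing in for whatever the stack resolved for Api.inference:

    from llama_stack.providers.datatypes import Api

    async def wire_up(inference_impl):
        # inference_impl is hypothetical: the resolved Api.inference provider
        config = MongoDBVectorIOConfig(
            connection_str="mongodb+srv://user:pass@cluster0.example.mongodb.net/",
            namespace="rag_db.document_chunks",
        )
        return await get_adapter_impl(config, {Api.inference: inference_impl})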

@@ -4,26 +4,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List
 
 from pydantic import BaseModel, Field
 
 
 class MongoDBVectorIOConfig(BaseModel):
-    conncetion_str: str
-    namespace: str = Field(None, description="Namespace of the MongoDB collection")
-    index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB collection")
-    filter_fields: Optional[str] = Field(None, description="Fields to filter the MongoDB collection")
-    embedding_field: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB collection")
-    text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB collection")
+    connection_str: str = Field(None, description="Connection string for the MongoDB Atlas collection")
+    namespace: str = Field(None, description="Namespace, i.e. db_name.collection_name, of the MongoDB Atlas collection")
+    index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB Atlas collection")
+    filter_fields: Optional[List[str]] = Field([], description="Fields to filter alongside vector search in the MongoDB Atlas collection")
+    embeddings_key: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB Atlas collection")
+    text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB Atlas collection")
 
     @classmethod
     def sample_config(cls) -> Dict[str, Any]:
         return {
             "connection_str": "{env.MONGODB_CONNECTION_STR}",
             "namespace": "{env.MONGODB_NAMESPACE}",
             "index_name": "{env.MONGODB_INDEX_NAME}",
             "filter_fields": "{env.MONGODB_FILTER_FIELDS}",
             "embedding_field": "{env.MONGODB_EMBEDDING_FIELD}",
             "text_field": "{env.MONGODB_TEXT_FIELD}",
         }
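
For reference, a sketch of constructing the updated config directly; the import path is an assumption and the field values are illustrative:

    from llama_stack.providers.remote.vector_io.mongodb.config import MongoDBVectorIOConfig  # assumed path

    config = MongoDBVectorIOConfig(
        connection_str="mongodb+srv://user:pass@cluster0.example.mongodb.net/",
        namespace="rag_db.document_chunks",   # db_name.collection_name
        index_name="vector_index",
        filter_fields=["metadata.document_id"],
    )
    print(config.model_dump_json(indent=2))

Note that sample_config still emits the old "embedding_field" key while the renamed field is embeddings_key, so the env-driven sample will not populate that field.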

@@ -1,5 +1,3 @@
-import pymongo
-
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
@@ -11,8 +9,8 @@ import logging
 from typing import Any, Dict, List, Optional, Union
 from urllib.parse import urlparse
 
-from pymongo import MongoClient,
-from pymongo.operations import UpdateOne, InsertOne, DeleteOne, DeleteMany, SearchIndexModel
+from pymongo import MongoClient
+from pymongo.operations import InsertOne, SearchIndexModel, UpdateOne
 import certifi
 from numpy.typing import NDArray
 
@@ -25,19 +23,24 @@ from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
 )
 
-from .config import MongoDBVectorIOConfig
-
+from .config import MongoDBAtlasVectorIOConfig
+from time import sleep
 
 log = logging.getLogger(__name__)
 CHUNK_ID_KEY = "_chunk_id"
 
 
 class MongoDBAtlasIndex(EmbeddingIndex):
-    def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, index_name: str):
+    def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, embedding_dimension: int, index_name: str, filter_fields: List[str]):
         self.client = client
         self.namespace = namespace
         self.embeddings_key = embeddings_key
         self.index_name = index_name
+        self.filter_fields = filter_fields
+        self.embedding_dimension = embedding_dimension
 
     def _get_index_config(self, collection, index_name):
         idxs = list(collection.list_search_indexes())
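
A construction sketch for the index wrapper, matching the new __init__ signature above; the connection string and namespace are illustrative:

    import certifi
    from pymongo import MongoClient

    client = MongoClient(
        "mongodb+srv://user:pass@cluster0.example.mongodb.net/",
        tlsCAFile=certifi.where(),  # same TLS setup the adapter uses in initialize()
    )
    index = MongoDBAtlasIndex(
        client=client,
        namespace="rag_db.document_chunks",   # split into db and collection on "."
        embeddings_key="embeddings",
        embedding_dimension=384,
        index_name="vector_index",
        filter_fields=[],
    )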

@@ -45,14 +48,39 @@ class MongoDBAtlasIndex(EmbeddingIndex):
             if ele["name"] == index_name:
                 return ele
 
+    def _get_search_index_model(self):
+        index_fields = [
+            {
+                "path": self.embeddings_key,
+                "type": "vector",
+                "numDimensions": self.embedding_dimension,
+                "similarity": "cosine"
+            }
+        ]
+
+        if len(self.filter_fields) > 0:
+            for filter_field in self.filter_fields:
+                index_fields.append(
+                    {
+                        "path": filter_field,
+                        "type": "filter"
+                    }
+                )
+
+        return SearchIndexModel(
+            name=self.index_name,
+            type="vectorSearch",
+            definition={
+                "fields": index_fields
+            }
+        )
+
     def _check_n_create_index(self):
         client = self.client
-        db,collection = self.namespace.split(".")
+        db, collection = self.namespace.split(".")
         collection = client[db][collection]
         index_name = self.index_name
-        print(">>>>>>>>Index name: ", index_name, "<<<<<<<<<<")
         idx = self._get_index_config(collection, index_name)
-        print(idx)
         if not idx:
             log.info("Creating search index ...")
             search_index_model = self._get_search_index_model()
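
Concretely, for embedding_dimension=384 and one filter field, _get_search_index_model builds the Atlas definition below; a sketch assuming a pymongo recent enough for SearchIndexModel's type parameter:

    from pymongo.operations import SearchIndexModel

    model = SearchIndexModel(
        name="vector_index",
        type="vectorSearch",
        definition={
            "fields": [
                # one vector field sized to the embedding width, cosine similarity
                {"path": "embeddings", "type": "vector", "numDimensions": 384, "similarity": "cosine"},
                # plus one "filter" entry per configured filter field
                {"path": "metadata.document_id", "type": "filter"},
            ]
        },
    )
    # collection.create_search_index(model); Atlas builds the index asynchronously,
    # which is why _check_n_create_index polls until it reports queryable.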

@@ -60,10 +88,10 @@ class MongoDBAtlasIndex(EmbeddingIndex):
             while True:
                 idx = self._get_index_config(collection, index_name)
                 if idx and idx["queryable"]:
-                    print("Search index created successfully.")
+                    log.info("Search index created successfully.")
                     break
                 else:
-                    print("Waiting for search index to be created ...")
+                    log.info("Waiting for search index to be created ...")
                     sleep(5)
         else:
             log.info("Search index already exists.")

@@ -77,46 +105,53 @@
         operations = []
         for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
             chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
 
             operations.append(
-                InsertOne(
+                UpdateOne(
+                    {CHUNK_ID_KEY: chunk_id},
                     {
-                        CHUNK_ID_KEY: chunk_id,
-                        "chunk_content": chunk.model_dump_json(),
-                        self.embeddings_key: embedding.tolist(),
-                    }
+                        "$set": {
+                            CHUNK_ID_KEY: chunk_id,
+                            "chunk_content": json.loads(chunk.model_dump_json()),
+                            self.embeddings_key: embedding.tolist(),
+                        }
+                    },
+                    upsert=True,
                 )
             )
 
         # Perform the bulk operations
-        db,collection_name = self.namespace.split(".")
+        db, collection_name = self.namespace.split(".")
         collection = self.client[db][collection_name]
         collection.bulk_write(operations)
-        print(f"Added {len(chunks)} chunks to the collection")
         # Create a search index model
-        print("Creating search index ...")
         self._check_n_create_index()
 
     async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         # Perform a query
-        db,collection_name = self.namespace.split(".")
+        db, collection_name = self.namespace.split(".")
         collection = self.client[db][collection_name]
 
         # Create vector search query
         vs_query = {"$vectorSearch":
             {
-                "index": "vector_index",
+                "index": self.index_name,
                 "path": self.embeddings_key,
                 "queryVector": embedding.tolist(),
                 "numCandidates": k,
                 "limit": k,
             }
         }
         # Add a field to store the score
         score_add_field_query = {
             "$addFields": {
                 "score": {"$meta": "vectorSearchScore"}
             }
         }
+        if score_threshold is None:
+            score_threshold = 0.01
         # Filter the results based on the score threshold
         filter_query = {
             "$match": {
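
Assembled, query() runs a three-stage aggregation pipeline. The $match body is truncated at the hunk boundary above, so the comparison below is an assumption (score >= score_threshold is the natural reading):

    k = 5
    score_threshold = 0.01
    query_vector = [0.0] * 384  # embedding.tolist() in the real code

    pipeline = [
        {"$vectorSearch": {
            "index": "vector_index",       # self.index_name
            "path": "embeddings",          # self.embeddings_key
            "queryVector": query_vector,
            "numCandidates": k,
            "limit": k,
        }},
        {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
        {"$match": {"score": {"$gte": score_threshold}}},  # assumed threshold filter
    ]
    # results = collection.aggregate(pipeline)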

@@ -141,60 +176,90 @@ class MongoDBAtlasIndex(EmbeddingIndex):
         chunks = []
         scores = []
         for result in results:
+            content = result["chunk_content"]
             chunk = Chunk(
-                metadata={"document_id": result[CHUNK_ID_KEY]},
-                content=json.loads(result["chunk_content"]),
+                metadata=content["metadata"],
+                content=content["content"],
             )
             chunks.append(chunk)
             scores.append(result["score"])
 
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
+    async def delete(self):
+        db, _ = self.namespace.split(".")
+        self.client.drop_database(db)
+
 
-class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: MongoDBAtlasVectorIOConfig, inference_api: Api.inference):
+class MongoDBVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
+    def __init__(self, config: MongoDBVectorIOConfig, inference_api: Api.inference):
         self.config = config
         self.inference_api = inference_api
 
         self.cache = {}
 
     async def initialize(self) -> None:
         self.client = MongoClient(
-            self.config.uri,
+            self.config.connection_str,
             tlsCAFile=certifi.where(),
         )
         self.cache = {}
-        pass
 
     async def shutdown(self) -> None:
-        self.client.close()
-        pass
+        if self.client:
+            self.client.close()
 
-    async def register_vector_db( self, vector_db: VectorDB) -> None:
-        index = VectorDBWithIndex(
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
+        index = MongoDBAtlasIndex(
+            client=self.client,
+            namespace=self.config.namespace,
+            embeddings_key=self.config.embeddings_key,
+            embedding_dimension=vector_db.embedding_dimension,
+            index_name=self.config.index_name,
+            filter_fields=self.config.filter_fields,
+        )
+        self.cache[vector_db.identifier] = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=index,
+            inference_api=self.inference_api,
+        )
 
+    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]
+        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
+        self.cache[vector_db_id] = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=MongoDBAtlasIndex(
+                client=self.client,
+                namespace=self.config.namespace,
+                embeddings_key=self.config.embeddings_key,
+                embedding_dimension=vector_db.embedding_dimension,
+                index_name=self.config.index_name,
+                filter_fields=self.config.filter_fields,
+            ),
+            inference_api=self.inference_api,
+        )
-        self.cache[vector_db] = index
-        pass
+        return self.cache[vector_db_id]
 
-    async def insert_chunks(self, vector_db_id, chunks, ttl_seconds = None):
-        index = self.cache[vector_db_id].index
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        await self.cache[vector_db_id].index.delete()
+        del self.cache[vector_db_id]
+
+    async def insert_chunks(self,
+        vector_db_id: str,
+        chunks: List[Chunk],
+        ttl_seconds: Optional[int] = None,
+    ) -> None:
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        if not index:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
         await index.insert_chunks(chunks)
 
-    async def query_chunks(self, vector_db_id, query, params = None):
-        index = self.cache[vector_db_id].index
+    async def query_chunks(self,
+        vector_db_id: str,
+        query: InterleavedContent,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> QueryChunksResponse:
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        if not index:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
         return await index.query_chunks(query, params)
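
On the "json object conversion" fix: storing json.loads(chunk.model_dump_json()) writes the chunk as a BSON subdocument rather than an opaque JSON string, which is what lets query() read content["metadata"] and content["content"] back without a second json.loads. A self-contained sketch with a stand-in Chunk model:

    import json

    from pydantic import BaseModel

    class Chunk(BaseModel):  # stand-in for llama_stack's Chunk
        content: str
        metadata: dict

    chunk = Chunk(content="hello world", metadata={"document_id": "doc-1"})

    as_string = chunk.model_dump_json()              # stored as one opaque string
    as_subdoc = json.loads(chunk.model_dump_json())  # stored as a queryable subdocument

    assert as_subdoc["metadata"]["document_id"] == "doc-1"  # what query() now relies on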