mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
docs: Fix incorrect link and command for generating API reference (#1124)
added MongoDB vector io
This commit is contained in:
parent
66d6c2580e
commit
bfece15fb4
3 changed files with 229 additions and 0 deletions
29
llama_stack/providers/remote/vector_io/mongodb/config.py
Normal file
29
llama_stack/providers/remote/vector_io/mongodb/config.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class MongoDBVectorIOConfig(BaseModel):
|
||||||
|
conncetion_str: str
|
||||||
|
namespace: str = Field(None, description="Namespace of the MongoDB collection")
|
||||||
|
index_name: Optional[str] = Field("default", description="Name of the index in the MongoDB collection")
|
||||||
|
filter_fields: Optional[str] = Field(None, description="Fields to filter the MongoDB collection")
|
||||||
|
embedding_field: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB collection")
|
||||||
|
text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB collection")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_config(cls) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"connection_str": "{env.MONGODB_CONNECTION_STR}",
|
||||||
|
"namespace": "{env.MONGODB_NAMESPACE}",
|
||||||
|
"index_name": "{env.MONGODB_INDEX_NAME}",
|
||||||
|
"filter_fields": "{env.MONGODB_FILTER_FIELDS}",
|
||||||
|
"embedding_field": "{env.MONGODB_EMBEDDING_FIELD}",
|
||||||
|
"text_field": "{env.MONGODB_TEXT_FIELD}",
|
||||||
|
}
|
200
llama_stack/providers/remote/vector_io/mongodb/mongodb_atlas.py
Normal file
200
llama_stack/providers/remote/vector_io/mongodb/mongodb_atlas.py
Normal file
|
@ -0,0 +1,200 @@
|
||||||
|
import pymongo
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from pymongo import MongoClient,
|
||||||
|
from pymongo.operations import UpdateOne, InsertOne, DeleteOne, DeleteMany, SearchIndexModel
|
||||||
|
import certifi
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import InterleavedContent
|
||||||
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
EmbeddingIndex,
|
||||||
|
VectorDBWithIndex,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .config import MongoDBAtlasVectorIOConfig
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
CHUNK_ID_KEY = "_chunk_id"
|
||||||
|
|
||||||
|
class MongoDBAtlasIndex(EmbeddingIndex):
|
||||||
|
def __init__(self, client: MongoClient, namespace: str, embeddings_key: str, index_name: str):
|
||||||
|
self.client = client
|
||||||
|
self.namespace = namespace
|
||||||
|
self.embeddings_key = embeddings_key
|
||||||
|
self.index_name = index_name
|
||||||
|
|
||||||
|
def _get_index_config(self, collection, index_name):
|
||||||
|
idxs = list(collection.list_search_indexes())
|
||||||
|
for ele in idxs:
|
||||||
|
if ele["name"] == index_name:
|
||||||
|
return ele
|
||||||
|
|
||||||
|
def _check_n_create_index(self):
|
||||||
|
client = self.client
|
||||||
|
db,collection = self.namespace.split(".")
|
||||||
|
collection = client[db][collection]
|
||||||
|
index_name = self.index_name
|
||||||
|
print(">>>>>>>>Index name: ", index_name, "<<<<<<<<<<")
|
||||||
|
idx = self._get_index_config(collection, index_name)
|
||||||
|
print(idx)
|
||||||
|
if not idx:
|
||||||
|
log.info("Creating search index ...")
|
||||||
|
search_index_model = self._get_search_index_model()
|
||||||
|
collection.create_search_index(search_index_model)
|
||||||
|
while True:
|
||||||
|
idx = self._get_index_config(collection, index_name)
|
||||||
|
if idx and idx["queryable"]:
|
||||||
|
print("Search index created successfully.")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print("Waiting for search index to be created ...")
|
||||||
|
sleep(5)
|
||||||
|
else:
|
||||||
|
log.info("Search index already exists.")
|
||||||
|
|
||||||
|
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||||
|
assert len(chunks) == len(embeddings), (
|
||||||
|
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a list of operations to perform in bulk
|
||||||
|
operations = []
|
||||||
|
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||||
|
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
|
||||||
|
operations.append(
|
||||||
|
InsertOne(
|
||||||
|
{
|
||||||
|
CHUNK_ID_KEY: chunk_id,
|
||||||
|
"chunk_content": chunk.model_dump_json(),
|
||||||
|
self.embeddings_key: embedding.tolist(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform the bulk operations
|
||||||
|
db,collection_name = self.namespace.split(".")
|
||||||
|
collection = self.client[db][collection_name]
|
||||||
|
collection.bulk_write(operations)
|
||||||
|
|
||||||
|
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
|
|
||||||
|
# Create a search index model
|
||||||
|
self._check_n_create_index()
|
||||||
|
|
||||||
|
# Perform a query
|
||||||
|
db,collection_name = self.namespace.split(".")
|
||||||
|
collection = self.client[db][collection_name]
|
||||||
|
|
||||||
|
# Create vector search query
|
||||||
|
vs_query = {"$vectorSearch":
|
||||||
|
{
|
||||||
|
"index": "vector_index",
|
||||||
|
"path": self.embeddings_key,
|
||||||
|
"queryVector": embedding.tolist(),
|
||||||
|
"numCandidates": k,
|
||||||
|
"limit": k,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# Add a field to store the score
|
||||||
|
score_add_field_query = {
|
||||||
|
"$addFields": {
|
||||||
|
"score": {"$meta": "vectorSearchScore"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# Filter the results based on the score threshold
|
||||||
|
filter_query = {
|
||||||
|
"$match": {
|
||||||
|
"score": {"$gt": score_threshold}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
project_query = {
|
||||||
|
"$project": {
|
||||||
|
CHUNK_ID_KEY: 1,
|
||||||
|
"chunk_content": 1,
|
||||||
|
"score": 1,
|
||||||
|
"_id": 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
query = [vs_query, score_add_field_query, filter_query, project_query]
|
||||||
|
|
||||||
|
results = collection.aggregate(query)
|
||||||
|
|
||||||
|
# Create the response
|
||||||
|
chunks = []
|
||||||
|
scores = []
|
||||||
|
for result in results:
|
||||||
|
chunk = Chunk(
|
||||||
|
metadata={"document_id": result[CHUNK_ID_KEY]},
|
||||||
|
content=json.loads(result["chunk_content"]),
|
||||||
|
)
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(result["score"])
|
||||||
|
|
||||||
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
|
def __init__(self, config: MongoDBAtlasVectorIOConfig, inference_api: Api.inference):
|
||||||
|
self.config = config
|
||||||
|
self.inference_api = inference_api
|
||||||
|
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
self.client = MongoClient(
|
||||||
|
self.config.uri,
|
||||||
|
tlsCAFile=certifi.where(),
|
||||||
|
)
|
||||||
|
self.cache = {}
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
self.client.close()
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_vector_db( self, vector_db: VectorDB) -> None:
|
||||||
|
index = VectorDBWithIndex(
|
||||||
|
vector_db=vector_db,
|
||||||
|
index=MongoDBAtlasIndex(
|
||||||
|
client=self.client,
|
||||||
|
namespace=self.config.namespace,
|
||||||
|
embeddings_key=self.config.embeddings_key,
|
||||||
|
index_name=self.config.index_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.cache[vector_db] = index
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def insert_chunks(self, vector_db_id, chunks, ttl_seconds = None):
|
||||||
|
index = self.cache[vector_db_id].index
|
||||||
|
if not index:
|
||||||
|
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||||
|
await index.insert_chunks(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
async def query_chunks(self, vector_db_id, query, params = None):
|
||||||
|
index = self.cache[vector_db_id].index
|
||||||
|
if not index:
|
||||||
|
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||||
|
return await index.query_chunks(query, params)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue