updated batching and config

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-02-18 00:09:22 -05:00
parent 663396ec5a
commit 9ee59443aa
2 changed files with 5 additions and 6 deletions

View file

@@ -18,7 +18,6 @@ from llama_stack.providers.utils.kvstore.config import (
class SQLiteVectorIOConfig(BaseModel): class SQLiteVectorIOConfig(BaseModel):
db_path: str db_path: str
kvstore: KVStoreConfig kvstore: KVStoreConfig
batch_size: bool = 500
@classmethod @classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:

View file

@@ -69,7 +69,7 @@ class SQLiteVecIndex(EmbeddingIndex):
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};") cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
self.connection.commit() self.connection.commit()
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
""" """
Add new chunks along with their embeddings using batch inserts. Add new chunks along with their embeddings using batch inserts.
For each chunk, we insert its JSON into the metadata table and then insert its For each chunk, we insert its JSON into the metadata table and then insert its
@@ -80,9 +80,9 @@ class SQLiteVecIndex(EmbeddingIndex):
try: try:
# Start transaction # Start transaction
cur.execute("BEGIN TRANSACTION") cur.execute("BEGIN TRANSACTION")
for i in range(0, len(chunks), self.config.batch_size): for i in range(0, len(chunks), batch_size):
batch_chunks = chunks[i : i + self.config.batch_size] batch_chunks = chunks[i : i + batch_size]
batch_embeddings = embeddings[i : i + self.config.batch_size] batch_embeddings = embeddings[i : i + batch_size]
# Prepare metadata inserts # Prepare metadata inserts
metadata_data = [ metadata_data = [
(f"{chunk.metadata['document_id']}:chunk-{j}", chunk.model_dump_json()) (f"{chunk.metadata['document_id']}:chunk-{j}", chunk.model_dump_json())
@@ -218,7 +218,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
if vector_db_id not in self.cache: if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}") raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
# and then call our index's add_chunks. # and then call our index's add_chunks.
await self.cache[vector_db_id].insert_chunks(chunks) await self.cache[vector_db_id].insert_chunks(chunks)
async def query_chunks( async def query_chunks(