feat: Enhance Vector Stores config with full configurations (#4397)

# What does this PR do?

Enhances the Vector Stores config with a full set of appropriate
configurations:
- Add `FileIngestionParams`, `ChunkRetrievalParams`, and `FileBatchParams`
subconfigs (see the sketch after this list)
- Update RAG memory, the OpenAI vector store mixin, and the vector store utils
to use the new configuration
  - Fix import organization across vector store components
  - Add comprehensive vector stores configuration documentation
  - Update docs navigation to include the vector store configuration guide
- Delete `memory/constants.py` and move the constant values directly into the
Pydantic models
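
For context, here is a minimal sketch of how the new subconfigs nest under `VectorStoresConfig`. Field names and defaults are taken from the diff below; the import path for the subconfig classes is an assumption (they are defined alongside `VectorStoresConfig`).

```python
# Sketch only: constructing the enhanced config and reading the new defaults.
from llama_stack.core.datatypes import (  # subconfig import path assumed
    ChunkRetrievalParams,
    FileBatchParams,
    FileIngestionParams,
    VectorStoresConfig,
)

config = VectorStoresConfig()  # every subconfig falls back to its defaults
assert config.file_ingestion_params.default_chunk_size_tokens == 512
assert config.chunk_retrieval_params.chunk_multiplier == 5
assert config.file_batch_params.cleanup_interval_seconds == 86400  # 24 hours

# Each subconfig can be overridden independently.
tuned = VectorStoresConfig(
    chunk_retrieval_params=ChunkRetrievalParams(max_tokens_in_context=8000),
    file_batch_params=FileBatchParams(max_concurrent_files_per_batch=5),
    file_ingestion_params=FileIngestionParams(default_chunk_overlap_tokens=64),
)
```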

## Test Plan
Tests updated + CI

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
22 changed files with 3249 additions and 110 deletions

View file

@ -18,15 +18,6 @@ from llama_stack.core.storage.datatypes import (
StorageConfig,
)
from llama_stack.log import LoggingConfig
from llama_stack.providers.utils.memory.constants import (
DEFAULT_ANNOTATION_INSTRUCTION_TEMPLATE,
DEFAULT_CHUNK_ANNOTATION_TEMPLATE,
DEFAULT_CHUNK_WITH_SOURCES_TEMPLATE,
DEFAULT_CONTEXT_TEMPLATE,
DEFAULT_FILE_SEARCH_FOOTER_TEMPLATE,
DEFAULT_FILE_SEARCH_HEADER_TEMPLATE,
DEFAULT_QUERY_REWRITE_PROMPT,
)
from llama_stack_api import (
Api,
Benchmark,
@ -367,7 +358,7 @@ class RewriteQueryParams(BaseModel):
description="LLM model for query rewriting/expansion in vector search.",
)
prompt: str = Field(
default=DEFAULT_QUERY_REWRITE_PROMPT,
default="Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:",
description="Prompt template for query rewriting. Use {query} as placeholder for the original query.",
)
max_tokens: int = Field(
@ -407,11 +398,11 @@ class FileSearchParams(BaseModel):
"""Configuration for file search tool output formatting."""
header_template: str = Field(
default=DEFAULT_FILE_SEARCH_HEADER_TEMPLATE,
default="knowledge_search tool found {num_chunks} chunks:\nBEGIN of knowledge_search tool results.\n",
description="Template for the header text shown before search results. Available placeholders: {num_chunks} number of chunks found.",
)
footer_template: str = Field(
default=DEFAULT_FILE_SEARCH_FOOTER_TEMPLATE,
default="END of knowledge_search tool results.\n",
description="Template for the footer text shown after search results.",
)
@ -433,11 +424,11 @@ class ContextPromptParams(BaseModel):
"""Configuration for LLM prompt content and chunk formatting."""
chunk_annotation_template: str = Field(
default=DEFAULT_CHUNK_ANNOTATION_TEMPLATE,
default="Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
description="Template for formatting individual chunks in search results. Available placeholders: {index} 1-based chunk index, {chunk.content} chunk content, {metadata} chunk metadata dict.",
)
context_template: str = Field(
default=DEFAULT_CONTEXT_TEMPLATE,
default='The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query. {annotation_instruction}\n',
description="Template for explaining the search results to the model. Available placeholders: {query} user's query, {num_chunks} number of chunks.",
)
@ -470,11 +461,11 @@ class AnnotationPromptParams(BaseModel):
description="Whether to include annotation information in results.",
)
annotation_instruction_template: str = Field(
default=DEFAULT_ANNOTATION_INSTRUCTION_TEMPLATE,
default="Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format like 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'. Do not add extra punctuation. Use only the file IDs provided, do not invent new ones.",
description="Instructions for how the model should cite sources. Used when enable_annotations is True.",
)
chunk_annotation_template: str = Field(
default=DEFAULT_CHUNK_WITH_SOURCES_TEMPLATE,
default="[{index}] {metadata_text} cite as <|{file_id}|>\n{chunk_text}\n",
description="Template for chunks with annotation information. Available placeholders: {index} 1-based chunk index, {metadata_text} formatted metadata, {file_id} document identifier, {chunk_text} chunk content.",
)
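To make the template placeholders concrete, here is a purely illustrative rendering of the default header and annotated-chunk templates; the metadata, file id, and chunk text below are invented sample values.

```python
# Illustrative only: formatting the default templates with made-up values.
header_tpl = "knowledge_search tool found {num_chunks} chunks:\nBEGIN of knowledge_search tool results.\n"
chunk_tpl = "[{index}] {metadata_text} cite as <|{file_id}|>\n{chunk_text}\n"

print(header_tpl.format(num_chunks=2))
print(chunk_tpl.format(
    index=1,
    metadata_text="title: install guide",   # invented metadata
    file_id="file-Cn3MSNn72ENTiiq11Qda4A",  # example id reused from the default instruction
    chunk_text="Run the installer, then start the stack.",
))
# The footer template ("END of knowledge_search tool results.\n") would follow the results.
```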
@ -499,6 +490,61 @@ class AnnotationPromptParams(BaseModel):
return v
class FileIngestionParams(BaseModel):
"""Configuration for file processing during ingestion."""
default_chunk_size_tokens: int = Field(
default=512,
description="Default chunk size for RAG tool operations when not specified",
)
default_chunk_overlap_tokens: int = Field(
default=128,
description="Default overlap in tokens between chunks (original default: 512 // 4 = 128)",
)
class ChunkRetrievalParams(BaseModel):
"""Configuration for chunk retrieval and ranking during search."""
chunk_multiplier: int = Field(
default=5,
description="Multiplier for OpenAI API over-retrieval (affects all providers)",
)
max_tokens_in_context: int = Field(
default=4000,
description="Maximum tokens allowed in RAG context before truncation",
)
default_reranker_strategy: str = Field(
default="rrf",
description="Default reranker when not specified: 'rrf', 'weighted', or 'normalized'",
)
rrf_impact_factor: float = Field(
default=60.0,
description="Impact factor for RRF (Reciprocal Rank Fusion) reranking",
)
weighted_search_alpha: float = Field(
default=0.5,
description="Alpha weight for weighted search reranking (0.0-1.0)",
)
class FileBatchParams(BaseModel):
"""Configuration for file batch processing."""
max_concurrent_files_per_batch: int = Field(
default=3,
description="Maximum files processed concurrently in file batches",
)
file_batch_chunk_size: int = Field(
default=10,
description="Number of files to process in each batch chunk",
)
cleanup_interval_seconds: int = Field(
default=86400, # 24 hours
description="Interval for cleaning up expired file batches (seconds)",
)
class VectorStoresConfig(BaseModel):
"""Configuration for vector stores in the stack."""
@ -527,6 +573,19 @@ class VectorStoresConfig(BaseModel):
description="Configuration for source annotation and attribution features.",
)
file_ingestion_params: FileIngestionParams = Field(
default_factory=FileIngestionParams,
description="Configuration for file processing during ingestion.",
)
chunk_retrieval_params: ChunkRetrievalParams = Field(
default_factory=ChunkRetrievalParams,
description="Configuration for chunk retrieval and ranking during search.",
)
file_batch_params: FileBatchParams = Field(
default_factory=FileBatchParams,
description="Configuration for file batch processing.",
)
class SafetyConfig(BaseModel):
"""Configuration for default moderations model."""

View file

@ -11,6 +11,9 @@ def redact_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret", "token"]
# Specific configuration field names that should NOT be redacted despite containing "token"
safe_token_fields = ["chunk_size_tokens", "max_tokens", "default_chunk_overlap_tokens"]
def _redact_value(v: Any) -> Any:
if isinstance(v, dict):
return _redact_dict(v)
@ -21,7 +24,10 @@ def redact_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]:
def _redact_dict(d: dict[str, Any]) -> dict[str, Any]:
result = {}
for k, v in d.items():
if any(pattern in k.lower() for pattern in sensitive_patterns):
# Don't redact if it's a safe field
if any(safe_field in k.lower() for safe_field in safe_token_fields):
result[k] = _redact_value(v)
elif any(pattern in k.lower() for pattern in sensitive_patterns):
result[k] = "********"
else:
result[k] = _redact_value(v)
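
A small sketch of the effect of the `safe_token_fields` allow-list, assuming `redact_sensitive_fields` applies `_redact_dict` to the top-level dict as the surrounding function suggests (the input config is invented):

```python
# Sketch of the intended redaction behavior with an invented input.
config = {
    "api_key": "sk-abc123",                   # matches a sensitive pattern -> redacted
    "max_tokens": 4096,                       # contains "token" but is a safe field -> kept
    "file_ingestion_params": {
        "default_chunk_overlap_tokens": 128,  # safe field -> kept
    },
}
redacted = redact_sensitive_fields(config)
# redacted == {
#     "api_key": "********",
#     "max_tokens": 4096,
#     "file_ingestion_params": {"default_chunk_overlap_tokens": 128},
# }
```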

View file

@ -296,19 +296,32 @@ vector_stores:
'
context_template: 'The above results were retrieved to help answer the user''s
query: "{query}". Use them as supporting information only in answering this
query.{annotation_instruction}
query. {annotation_instruction}
'
annotation_prompt_params:
enable_annotations: true
annotation_instruction_template: ' Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like ''This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.''.
annotation_instruction_template: Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'.
Do not add extra punctuation. Use only the file IDs provided, do not invent
new ones.'
new ones.
chunk_annotation_template: '[{index}] {metadata_text} cite as <|{file_id}|>
{chunk_text}
'
file_ingestion_params:
default_chunk_size_tokens: 512
default_chunk_overlap_tokens: 128
chunk_retrieval_params:
chunk_multiplier: 5
max_tokens_in_context: 4000
default_reranker_strategy: rrf
rrf_impact_factor: 60.0
weighted_search_alpha: 0.5
file_batch_params:
max_concurrent_files_per_batch: 3
file_batch_chunk_size: 10
cleanup_interval_seconds: 86400
safety:
default_shield_id: llama-guard
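
A hedged sketch of how a run.yaml fragment like the one above maps onto the new Pydantic models. The YAML keys come from this hunk; the loading code is illustrative and is not the stack's actual startup path.

```python
# Illustrative only: validating a vector_stores YAML fragment against the new config models.
import yaml

from llama_stack.core.datatypes import VectorStoresConfig  # import path as used elsewhere in this PR

fragment = """
file_ingestion_params:
  default_chunk_size_tokens: 512
  default_chunk_overlap_tokens: 128
chunk_retrieval_params:
  chunk_multiplier: 5
  max_tokens_in_context: 4000
  default_reranker_strategy: rrf
file_batch_params:
  max_concurrent_files_per_batch: 3
  cleanup_interval_seconds: 86400
"""

cfg = VectorStoresConfig.model_validate(yaml.safe_load(fragment))
print(cfg.chunk_retrieval_params.max_tokens_in_context)  # 4000
```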

View file

@ -305,19 +305,32 @@ vector_stores:
'
context_template: 'The above results were retrieved to help answer the user''s
query: "{query}". Use them as supporting information only in answering this
query.{annotation_instruction}
query. {annotation_instruction}
'
annotation_prompt_params:
enable_annotations: true
annotation_instruction_template: ' Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like ''This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.''.
annotation_instruction_template: Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'.
Do not add extra punctuation. Use only the file IDs provided, do not invent
new ones.'
new ones.
chunk_annotation_template: '[{index}] {metadata_text} cite as <|{file_id}|>
{chunk_text}
'
file_ingestion_params:
default_chunk_size_tokens: 512
default_chunk_overlap_tokens: 128
chunk_retrieval_params:
chunk_multiplier: 5
max_tokens_in_context: 4000
default_reranker_strategy: rrf
rrf_impact_factor: 60.0
weighted_search_alpha: 0.5
file_batch_params:
max_concurrent_files_per_batch: 3
file_batch_chunk_size: 10
cleanup_interval_seconds: 86400
safety:
default_shield_id: llama-guard

View file

@ -299,19 +299,32 @@ vector_stores:
'
context_template: 'The above results were retrieved to help answer the user''s
query: "{query}". Use them as supporting information only in answering this
query.{annotation_instruction}
query. {annotation_instruction}
'
annotation_prompt_params:
enable_annotations: true
annotation_instruction_template: ' Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like ''This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.''.
annotation_instruction_template: Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'.
Do not add extra punctuation. Use only the file IDs provided, do not invent
new ones.'
new ones.
chunk_annotation_template: '[{index}] {metadata_text} cite as <|{file_id}|>
{chunk_text}
'
file_ingestion_params:
default_chunk_size_tokens: 512
default_chunk_overlap_tokens: 128
chunk_retrieval_params:
chunk_multiplier: 5
max_tokens_in_context: 4000
default_reranker_strategy: rrf
rrf_impact_factor: 60.0
weighted_search_alpha: 0.5
file_batch_params:
max_concurrent_files_per_batch: 3
file_batch_chunk_size: 10
cleanup_interval_seconds: 86400
safety:
default_shield_id: llama-guard

View file

@ -308,19 +308,32 @@ vector_stores:
'
context_template: 'The above results were retrieved to help answer the user''s
query: "{query}". Use them as supporting information only in answering this
query.{annotation_instruction}
query. {annotation_instruction}
'
annotation_prompt_params:
enable_annotations: true
annotation_instruction_template: ' Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like ''This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.''.
annotation_instruction_template: Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'.
Do not add extra punctuation. Use only the file IDs provided, do not invent
new ones.'
new ones.
chunk_annotation_template: '[{index}] {metadata_text} cite as <|{file_id}|>
{chunk_text}
'
file_ingestion_params:
default_chunk_size_tokens: 512
default_chunk_overlap_tokens: 128
chunk_retrieval_params:
chunk_multiplier: 5
max_tokens_in_context: 4000
default_reranker_strategy: rrf
rrf_impact_factor: 60.0
weighted_search_alpha: 0.5
file_batch_params:
max_concurrent_files_per_batch: 3
file_batch_chunk_size: 10
cleanup_interval_seconds: 86400
safety:
default_shield_id: llama-guard

View file

@ -296,19 +296,32 @@ vector_stores:
'
context_template: 'The above results were retrieved to help answer the user''s
query: "{query}". Use them as supporting information only in answering this
query.{annotation_instruction}
query. {annotation_instruction}
'
annotation_prompt_params:
enable_annotations: true
annotation_instruction_template: ' Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like ''This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.''.
annotation_instruction_template: Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'.
Do not add extra punctuation. Use only the file IDs provided, do not invent
new ones.'
new ones.
chunk_annotation_template: '[{index}] {metadata_text} cite as <|{file_id}|>
{chunk_text}
'
file_ingestion_params:
default_chunk_size_tokens: 512
default_chunk_overlap_tokens: 128
chunk_retrieval_params:
chunk_multiplier: 5
max_tokens_in_context: 4000
default_reranker_strategy: rrf
rrf_impact_factor: 60.0
weighted_search_alpha: 0.5
file_batch_params:
max_concurrent_files_per_batch: 3
file_batch_chunk_size: 10
cleanup_interval_seconds: 86400
safety:
default_shield_id: llama-guard

View file

@ -305,19 +305,32 @@ vector_stores:
'
context_template: 'The above results were retrieved to help answer the user''s
query: "{query}". Use them as supporting information only in answering this
query.{annotation_instruction}
query. {annotation_instruction}
'
annotation_prompt_params:
enable_annotations: true
annotation_instruction_template: ' Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like ''This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.''.
annotation_instruction_template: Cite sources immediately at the end of sentences
before punctuation, using `<|file-id|>` format like 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'.
Do not add extra punctuation. Use only the file IDs provided, do not invent
new ones.'
new ones.
chunk_annotation_template: '[{index}] {metadata_text} cite as <|{file_id}|>
{chunk_text}
'
file_ingestion_params:
default_chunk_size_tokens: 512
default_chunk_overlap_tokens: 128
chunk_retrieval_params:
chunk_multiplier: 5
max_tokens_in_context: 4000
default_reranker_strategy: rrf
rrf_impact_factor: 60.0
weighted_search_alpha: 0.5
file_batch_params:
max_concurrent_files_per_batch: 3
file_batch_chunk_size: 10
cleanup_interval_seconds: 86400
safety:
default_shield_id: llama-guard

View file

@ -11,11 +11,8 @@ from typing import Any
from opentelemetry import trace
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.providers.utils.memory.constants import (
DEFAULT_ANNOTATION_INSTRUCTION_TEMPLATE,
DEFAULT_CHUNK_WITH_SOURCES_TEMPLATE,
)
from llama_stack_api import (
ImageContentItem,
OpenAIChatCompletionContentPartImageParam,
@ -175,8 +172,10 @@ class ToolExecutor:
self.vector_stores_config.annotation_prompt_params.annotation_instruction_template
)
else:
chunk_annotation_template = DEFAULT_CHUNK_WITH_SOURCES_TEMPLATE
annotation_instruction_template = DEFAULT_ANNOTATION_INSTRUCTION_TEMPLATE
# Use defaults from VectorStoresConfig when annotations disabled
default_config = VectorStoresConfig()
chunk_annotation_template = default_config.annotation_prompt_params.chunk_annotation_template
annotation_instruction_template = default_config.annotation_prompt_params.annotation_instruction_template
content_items = []
content_items.append(TextContentItem(text=header_template.format(num_chunks=len(search_results))))

View file

@ -116,8 +116,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
self,
documents: list[RAGDocument],
vector_store_id: str,
chunk_size_in_tokens: int = 512,
chunk_size_in_tokens: int | None = None,
) -> None:
if chunk_size_in_tokens is None:
chunk_size_in_tokens = self.config.vector_stores_config.file_ingestion_params.default_chunk_size_tokens
if not documents:
return
@ -145,10 +147,11 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
continue
overlap_tokens = self.config.vector_stores_config.file_ingestion_params.default_chunk_overlap_tokens
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4,
chunk_overlap_tokens=overlap_tokens,
)
)
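With the shipped defaults, the new fixed overlap matches the old derived one, as this small arithmetic sketch shows; note, per the hunk above, that a caller-supplied chunk size no longer changes the overlap.

```python
# Sketch: the default overlap preserves the previous derived behavior.
default_chunk_size_tokens = 512
default_chunk_overlap_tokens = 128
assert default_chunk_size_tokens // 4 == default_chunk_overlap_tokens  # old 512 // 4 == new 128

# Unlike before, the overlap now always comes from
# vector_stores_config.file_ingestion_params.default_chunk_overlap_tokens,
# even when a non-default chunk_size_in_tokens is passed.
```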
@ -180,7 +183,9 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
"No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
)
query_config = query_config or RAGQueryConfig()
query_config = query_config or RAGQueryConfig(
max_tokens_in_context=self.config.vector_stores_config.chunk_retrieval_params.max_tokens_in_context
)
query = await generate_rag_query(
query_config.query_generator_config,
content,
@ -319,7 +324,9 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
query_config = RAGQueryConfig()
query_config = RAGQueryConfig(
max_tokens_in_context=self.config.vector_stores_config.chunk_retrieval_params.max_tokens_in_context
)
query = kwargs["query"]
result = await self.query(

View file

@ -4,6 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .constants import DEFAULT_QUERY_REWRITE_PROMPT
__all__ = ["DEFAULT_QUERY_REWRITE_PROMPT"]
__all__ = []

View file

@ -1,22 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Default prompt template for query rewriting in vector search
DEFAULT_QUERY_REWRITE_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
# Default templates for file search tool output formatting
DEFAULT_FILE_SEARCH_HEADER_TEMPLATE = (
"knowledge_search tool found {num_chunks} chunks:\nBEGIN of knowledge_search tool results.\n"
)
DEFAULT_FILE_SEARCH_FOOTER_TEMPLATE = "END of knowledge_search tool results.\n"
# Default templates for LLM prompt content and chunk formatting
DEFAULT_CHUNK_ANNOTATION_TEMPLATE = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
DEFAULT_CONTEXT_TEMPLATE = 'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{annotation_instruction}\n'
# Default templates for source annotation and attribution features
DEFAULT_ANNOTATION_INSTRUCTION_TEMPLATE = " Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format like 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'. Do not add extra punctuation. Use only the file IDs provided, do not invent new ones."
DEFAULT_CHUNK_WITH_SOURCES_TEMPLATE = "[{index}] {metadata_text} cite as <|{file_id}|>\n{chunk_text}\n"

View file

@ -15,6 +15,7 @@ from typing import Annotated, Any
from fastapi import Body
from pydantic import TypeAdapter
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.memory.vector_store import (
@ -59,10 +60,6 @@ EMBEDDING_DIMENSION = 768
logger = get_logger(name=__name__, category="providers::utils")
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5
FILE_BATCH_CLEANUP_INTERVAL_SECONDS = 24 * 60 * 60 # 1 day in seconds
MAX_CONCURRENT_FILES_PER_BATCH = 3 # Maximum concurrent file processing within a batch
FILE_BATCH_CHUNK_SIZE = 10 # Process files in chunks of this size
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_stores:{VERSION}::"
@ -85,11 +82,13 @@ class OpenAIVectorStoreMixin(ABC):
self,
files_api: Files | None = None,
kvstore: KVStore | None = None,
vector_stores_config: VectorStoresConfig | None = None,
):
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.openai_file_batches: dict[str, dict[str, Any]] = {}
self.files_api = files_api
self.kvstore = kvstore
self.vector_stores_config = vector_stores_config or VectorStoresConfig()
self._last_file_batch_cleanup_time = 0
self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
self._vector_store_locks: dict[str, asyncio.Lock] = {}
@ -619,7 +618,7 @@ class OpenAIVectorStoreMixin(ABC):
else 0.0
)
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"max_chunks": max_num_results * self.vector_stores_config.chunk_retrieval_params.chunk_multiplier,
"score_threshold": score_threshold,
"mode": search_mode,
}
@ -1072,7 +1071,10 @@ class OpenAIVectorStoreMixin(ABC):
# Run cleanup if needed (throttled to once every 1 day)
current_time = int(time.time())
if current_time - self._last_file_batch_cleanup_time >= FILE_BATCH_CLEANUP_INTERVAL_SECONDS:
if (
current_time - self._last_file_batch_cleanup_time
>= self.vector_stores_config.file_batch_params.cleanup_interval_seconds
):
logger.info("Running throttled cleanup of expired file batches")
asyncio.create_task(self._cleanup_expired_file_batches())
self._last_file_batch_cleanup_time = current_time
@ -1089,7 +1091,7 @@ class OpenAIVectorStoreMixin(ABC):
batch_info: dict[str, Any],
) -> None:
"""Process files with controlled concurrency and chunking."""
semaphore = asyncio.Semaphore(MAX_CONCURRENT_FILES_PER_BATCH)
semaphore = asyncio.Semaphore(self.vector_stores_config.file_batch_params.max_concurrent_files_per_batch)
async def process_single_file(file_id: str) -> tuple[str, bool]:
"""Process a single file with concurrency control."""
@ -1108,12 +1110,13 @@ class OpenAIVectorStoreMixin(ABC):
# Process files in chunks to avoid creating too many tasks at once
total_files = len(file_ids)
for chunk_start in range(0, total_files, FILE_BATCH_CHUNK_SIZE):
chunk_end = min(chunk_start + FILE_BATCH_CHUNK_SIZE, total_files)
chunk_size = self.vector_stores_config.file_batch_params.file_batch_chunk_size
for chunk_start in range(0, total_files, chunk_size):
chunk_end = min(chunk_start + chunk_size, total_files)
chunk = file_ids[chunk_start:chunk_end]
chunk_num = chunk_start // FILE_BATCH_CHUNK_SIZE + 1
total_chunks = (total_files + FILE_BATCH_CHUNK_SIZE - 1) // FILE_BATCH_CHUNK_SIZE
chunk_num = chunk_start // chunk_size + 1
total_chunks = (total_files + chunk_size - 1) // chunk_size
logger.info(
f"Processing chunk {chunk_num} of {total_chunks} ({len(chunk)} files, {chunk_start + 1}-{chunk_end} of {total_files} total files)"
)
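The chunked-batch bookkeeping above, worked through with illustrative numbers (the file count is invented; the chunk size is the new `file_batch_params` default):

```python
# Illustrative walk-through of the batching math above.
chunk_size = 10   # file_batch_params.file_batch_chunk_size default
total_files = 23  # invented

total_chunks = (total_files + chunk_size - 1) // chunk_size  # ceil(23 / 10) == 3
for chunk_start in range(0, total_files, chunk_size):
    chunk_end = min(chunk_start + chunk_size, total_files)
    chunk_num = chunk_start // chunk_size + 1
    print(f"chunk {chunk_num}/{total_chunks}: files {chunk_start + 1}-{chunk_end}")
# chunk 1/3: files 1-10
# chunk 2/3: files 11-20
# chunk 3/3: files 21-23
```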

View file

@ -17,6 +17,7 @@ import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -262,6 +263,7 @@ class VectorStoreWithIndex:
vector_store: VectorStore
index: EmbeddingIndex
inference_api: Api.inference
vector_stores_config: VectorStoresConfig | None = None
async def insert_chunks(
self,
@ -294,6 +296,8 @@ class VectorStoreWithIndex:
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
config = self.vector_stores_config or VectorStoresConfig()
if params is None:
params = {}
k = params.get("max_chunks", 3)
@ -302,19 +306,25 @@ class VectorStoreWithIndex:
ranker = params.get("ranker")
if ranker is None:
reranker_type = RERANKER_TYPE_RRF
reranker_params = {"impact_factor": 60.0}
reranker_type = (
RERANKER_TYPE_RRF
if config.chunk_retrieval_params.default_reranker_strategy == "rrf"
else config.chunk_retrieval_params.default_reranker_strategy
)
reranker_params = {"impact_factor": config.chunk_retrieval_params.rrf_impact_factor}
else:
strategy = ranker.get("strategy", "rrf")
strategy = ranker.get("strategy", config.chunk_retrieval_params.default_reranker_strategy)
if strategy == "weighted":
weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
reranker_type = RERANKER_TYPE_WEIGHTED
reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
reranker_params = {
"alpha": weights[0] if len(weights) > 0 else config.chunk_retrieval_params.weighted_search_alpha
}
elif strategy == "normalized":
reranker_type = RERANKER_TYPE_NORMALIZED
else:
reranker_type = RERANKER_TYPE_RRF
k_value = ranker.get("params", {}).get("k", 60.0)
k_value = ranker.get("params", {}).get("k", config.chunk_retrieval_params.rrf_impact_factor)
reranker_params = {"impact_factor": k_value}
query_string = interleaved_content_as_str(query)
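
A condensed, hedged restatement of the reranker-resolution logic above. Plain strategy strings stand in for the `RERANKER_TYPE_*` constants, the params for the normalized branch are not shown in this hunk, and this is a sketch rather than a drop-in replacement.

```python
# Sketch of how ranker params now resolve against ChunkRetrievalParams defaults.
from llama_stack.core.datatypes import VectorStoresConfig


def resolve_reranker(ranker: dict | None, cfg: VectorStoresConfig) -> tuple[str, dict]:
    retrieval = cfg.chunk_retrieval_params
    if ranker is None:
        return retrieval.default_reranker_strategy, {"impact_factor": retrieval.rrf_impact_factor}
    strategy = ranker.get("strategy", retrieval.default_reranker_strategy)
    if strategy == "weighted":
        weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
        alpha = weights[0] if weights else retrieval.weighted_search_alpha
        return "weighted", {"alpha": alpha}
    if strategy == "normalized":
        return "normalized", {}  # params for this branch are elided in the hunk above
    k = ranker.get("params", {}).get("k", retrieval.rrf_impact_factor)
    return "rrf", {"impact_factor": k}


# e.g. resolve_reranker(None, VectorStoresConfig()) -> ("rrf", {"impact_factor": 60.0})
```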