mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-18 02:42:31 +00:00
feat: Introduce weighted and rrf reranker implementations
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
eab85a7121
commit
6ea5c10d48
14 changed files with 637 additions and 75 deletions
|
@ -15,6 +15,48 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
|||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RRFRanker(BaseModel):
|
||||
"""
|
||||
Reciprocal Rank Fusion (RRF) ranker configuration.
|
||||
|
||||
:param type: The type of ranker, always "rrf"
|
||||
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
|
||||
Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009).
|
||||
"""
|
||||
|
||||
type: Literal["rrf"] = "rrf"
|
||||
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class WeightedRanker(BaseModel):
|
||||
"""
|
||||
Weighted ranker configuration that combines vector and keyword scores.
|
||||
|
||||
:param type: The type of ranker, always "weighted"
|
||||
:param alpha: Weight factor between 0 and 1.
|
||||
0 means only use keyword scores,
|
||||
1 means only use vector scores,
|
||||
values in between blend both scores.
|
||||
"""
|
||||
|
||||
type: Literal["weighted"] = "weighted"
|
||||
alpha: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
|
||||
)
|
||||
|
||||
|
||||
Ranker = Annotated[
|
||||
RRFRanker | WeightedRanker,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(Ranker, name="Ranker")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGDocument(BaseModel):
|
||||
"""
|
||||
|
@ -77,6 +119,7 @@ class RAGQueryConfig(BaseModel):
|
|||
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
||||
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
||||
:param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector".
|
||||
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
|
||||
"""
|
||||
|
||||
# This config defines how a query is generated using the messages
|
||||
|
@ -86,6 +129,7 @@ class RAGQueryConfig(BaseModel):
|
|||
max_chunks: int = 5
|
||||
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||
mode: str | None = None
|
||||
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
|
||||
|
||||
@field_validator("chunk_template")
|
||||
def validate_chunk_template(cls, v: str) -> str:
|
||||
|
|
|
@ -121,8 +121,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
vector_db_id=vector_db_id,
|
||||
query=query,
|
||||
params={
|
||||
"max_chunks": query_config.max_chunks,
|
||||
"mode": query_config.mode,
|
||||
"max_chunks": query_config.max_chunks,
|
||||
"score_threshold": 0.0,
|
||||
"ranker": query_config.ranker,
|
||||
},
|
||||
)
|
||||
for vector_db_id in vector_db_ids
|
||||
|
|
|
@ -137,6 +137,8 @@ class FaissIndex(EmbeddingIndex):
|
|||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in FAISS")
|
||||
|
||||
|
|
|
@ -27,7 +27,12 @@ from llama_stack.apis.vector_io import (
|
|||
)
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
RERANKER_TYPE_WEIGHTED,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -52,6 +57,59 @@ def _create_sqlite_connection(db_path):
|
|||
return connection
|
||||
|
||||
|
||||
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
|
||||
"""Normalize scores to [0,1] range using min-max normalization."""
|
||||
if not scores:
|
||||
return {}
|
||||
min_score = min(scores.values())
|
||||
max_score = max(scores.values())
|
||||
score_range = max_score - min_score
|
||||
if score_range > 0:
|
||||
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
|
||||
return {doc_id: 1.0 for doc_id in scores}
|
||||
|
||||
|
||||
def _weighted_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
alpha: float = 0.5,
|
||||
) -> dict[str, float]:
|
||||
"""ReRanker that uses weighted average of scores."""
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
normalized_vector_scores = _normalize_scores(vector_scores)
|
||||
normalized_keyword_scores = _normalize_scores(keyword_scores)
|
||||
|
||||
return {
|
||||
doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
|
||||
+ ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
|
||||
for doc_id in all_ids
|
||||
}
|
||||
|
||||
|
||||
def _rrf_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
impact_factor: float = 60.0,
|
||||
) -> dict[str, float]:
|
||||
"""ReRanker that uses Reciprocal Rank Fusion."""
|
||||
# Convert scores to ranks
|
||||
vector_ranks = {
|
||||
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
keyword_ranks = {
|
||||
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
rrf_scores = {}
|
||||
for doc_id in all_ids:
|
||||
vector_rank = vector_ranks.get(doc_id, float("inf"))
|
||||
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
|
||||
# RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
|
||||
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
|
||||
return rrf_scores
|
||||
|
||||
|
||||
class SQLiteVecIndex(EmbeddingIndex):
|
||||
"""
|
||||
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
||||
|
@ -299,60 +357,72 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str = RERANKER_TYPE_RRF,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
"""
|
||||
Hybrid search using Reciprocal Rank Fusion (RRF) to combine vector and keyword search results.
|
||||
RRF assigns scores based on the reciprocal of the rank position in each search method,
|
||||
then combines these scores to get a final ranking.
|
||||
Hybrid search using a configurable re-ranking strategy.
|
||||
|
||||
Args:
|
||||
embedding: The query embedding vector
|
||||
query_string: The text query for keyword search
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
reranker_type: Type of reranker to use ("rrf" or "weighted")
|
||||
reranker_params: Parameters for the reranker
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
if reranker_params is None:
|
||||
reranker_params = {}
|
||||
|
||||
# Get results from both search methods
|
||||
vector_response = await self.query_vector(embedding, k * 2, score_threshold)
|
||||
keyword_response = await self.query_keyword(query_string, k * 2, score_threshold)
|
||||
vector_response = await self.query_vector(embedding, k, score_threshold)
|
||||
keyword_response = await self.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
# Create dictionaries to store ranks for each method
|
||||
vector_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(vector_response.chunks)}
|
||||
keyword_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(keyword_response.chunks)}
|
||||
# Convert responses to score dictionaries using generate_chunk_id
|
||||
vector_scores = {
|
||||
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
|
||||
for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
|
||||
}
|
||||
keyword_scores = {
|
||||
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
|
||||
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
|
||||
}
|
||||
|
||||
# Calculate RRF scores for all unique document IDs
|
||||
all_ids = set(vector_ranks.keys()) | set(keyword_ranks.keys())
|
||||
rrf_scores = {}
|
||||
for doc_id in all_ids:
|
||||
vector_rank = vector_ranks.get(doc_id, float("inf"))
|
||||
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
|
||||
# RRF formula: score = 1/(k + r) where k is a constant and r is the rank
|
||||
rrf_scores[doc_id] = (1.0 / (60 + vector_rank)) + (1.0 / (60 + keyword_rank))
|
||||
# Combine scores using the specified reranker
|
||||
if reranker_type == RERANKER_TYPE_WEIGHTED:
|
||||
alpha = reranker_params.get("alpha", 0.5)
|
||||
combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
|
||||
else:
|
||||
# Default to RRF for None, RRF, or any unknown types
|
||||
impact_factor = reranker_params.get("impact_factor", 60.0)
|
||||
combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
|
||||
|
||||
# Sort by RRF score and get top k results
|
||||
sorted_ids = sorted(rrf_scores.keys(), key=lambda x: rrf_scores[x], reverse=True)[:k]
|
||||
# Sort by combined score and get top k results
|
||||
sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
top_k_items = sorted_items[:k]
|
||||
|
||||
# Combine results maintaining RRF scores
|
||||
# Filter by score threshold
|
||||
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
|
||||
|
||||
# Create a map of chunk_id to chunk for both responses
|
||||
chunk_map = {}
|
||||
for c in vector_response.chunks:
|
||||
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
|
||||
chunk_map[chunk_id] = c
|
||||
for c in keyword_response.chunks:
|
||||
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
|
||||
chunk_map[chunk_id] = c
|
||||
|
||||
# Use the map to look up chunks by their IDs
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc_id in sorted_ids:
|
||||
score = rrf_scores[doc_id]
|
||||
if score >= score_threshold:
|
||||
# Try to get from vector results first
|
||||
for chunk in vector_response.chunks:
|
||||
if chunk.metadata["document_id"] == doc_id:
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
break
|
||||
else:
|
||||
# If not in vector results, get from keyword results
|
||||
for chunk in keyword_response.chunks:
|
||||
if chunk.metadata["document_id"] == doc_id:
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
break
|
||||
for doc_id, score in filtered_items:
|
||||
if doc_id in chunk_map:
|
||||
chunks.append(chunk_map[doc_id])
|
||||
scores.append(score)
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
|
|
@ -111,6 +111,8 @@ class ChromaIndex(EmbeddingIndex):
|
|||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in Chroma")
|
||||
|
||||
|
|
|
@ -109,6 +109,8 @@ class MilvusIndex(EmbeddingIndex):
|
|||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in Milvus")
|
||||
|
||||
|
|
|
@ -134,6 +134,8 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in PGVector")
|
||||
|
||||
|
|
|
@ -118,6 +118,8 @@ class QdrantIndex(EmbeddingIndex):
|
|||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in Qdrant")
|
||||
|
||||
|
|
|
@ -98,6 +98,8 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in Weaviate")
|
||||
|
||||
|
|
|
@ -32,6 +32,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Constants for reranker types
|
||||
RERANKER_TYPE_RRF = "rrf"
|
||||
RERANKER_TYPE_WEIGHTED = "weighted"
|
||||
|
||||
|
||||
def parse_pdf(data: bytes) -> str:
|
||||
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
||||
|
@ -204,7 +208,13 @@ class EmbeddingIndex(ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def query_hybrid(
|
||||
self, embedding: NDArray, query_string: str, k: int, score_threshold: float
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -251,15 +261,29 @@ class VectorDBWithIndex:
|
|||
k = params.get("max_chunks", 3)
|
||||
mode = params.get("mode")
|
||||
score_threshold = params.get("score_threshold", 0.0)
|
||||
|
||||
# Get ranker configuration
|
||||
ranker = params.get("ranker")
|
||||
if ranker is None:
|
||||
# Default to RRF with impact_factor=60.0
|
||||
reranker_type = RERANKER_TYPE_RRF
|
||||
reranker_params = {"impact_factor": 60.0}
|
||||
else:
|
||||
reranker_type = ranker.type
|
||||
reranker_params = (
|
||||
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
|
||||
)
|
||||
|
||||
query_string = interleaved_content_as_str(query)
|
||||
if mode == "keyword":
|
||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
# Calculate embeddings for both vector and hybrid modes
|
||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||
|
||||
if mode == "keyword":
|
||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||
elif mode == "hybrid":
|
||||
return await self.index.query_hybrid(query_vector, query_string, k, score_threshold)
|
||||
if mode == "hybrid":
|
||||
return await self.index.query_hybrid(
|
||||
query_vector, query_string, k, score_threshold, reranker_type, reranker_params
|
||||
)
|
||||
else:
|
||||
return await self.index.query_vector(query_vector, k, score_threshold)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue