mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-18 19:02:30 +00:00
feat: Introduce weighted and rrf reranker implementations
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
eab85a7121
commit
6ea5c10d48
14 changed files with 637 additions and 75 deletions
|
@ -32,6 +32,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Constants for reranker types
|
||||
RERANKER_TYPE_RRF = "rrf"
|
||||
RERANKER_TYPE_WEIGHTED = "weighted"
|
||||
|
||||
|
||||
def parse_pdf(data: bytes) -> str:
|
||||
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
||||
|
@ -204,7 +208,13 @@ class EmbeddingIndex(ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def query_hybrid(
|
||||
self, embedding: NDArray, query_string: str, k: int, score_threshold: float
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -251,15 +261,29 @@ class VectorDBWithIndex:
|
|||
k = params.get("max_chunks", 3)
|
||||
mode = params.get("mode")
|
||||
score_threshold = params.get("score_threshold", 0.0)
|
||||
|
||||
# Get ranker configuration
|
||||
ranker = params.get("ranker")
|
||||
if ranker is None:
|
||||
# Default to RRF with impact_factor=60.0
|
||||
reranker_type = RERANKER_TYPE_RRF
|
||||
reranker_params = {"impact_factor": 60.0}
|
||||
else:
|
||||
reranker_type = ranker.type
|
||||
reranker_params = (
|
||||
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
|
||||
)
|
||||
|
||||
query_string = interleaved_content_as_str(query)
|
||||
if mode == "keyword":
|
||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
# Calculate embeddings for both vector and hybrid modes
|
||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||
|
||||
if mode == "keyword":
|
||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||
elif mode == "hybrid":
|
||||
return await self.index.query_hybrid(query_vector, query_string, k, score_threshold)
|
||||
if mode == "hybrid":
|
||||
return await self.index.query_hybrid(
|
||||
query_vector, query_string, k, score_threshold, reranker_type, reranker_params
|
||||
)
|
||||
else:
|
||||
return await self.index.query_vector(query_vector, k, score_threshold)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue