Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-23 21:04:29 +00:00
feat: Introduce weighted and rrf reranker implementations
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
Parent: eab85a7121
Commit: 6ea5c10d48
14 changed files with 637 additions and 75 deletions
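For orientation before the diff: the tests below exercise two ways of combining the vector and keyword result lists. The sketch that follows is illustrative only and is not the code added by this commit; the {chunk_id: score} input shape and the helper names are assumptions, the direction of alpha (weighting the keyword side) follows the comments in the new parametrization test, and impact_factor=60.0 mirrors the value used throughout the tests.

# Illustrative sketch only; not the implementation introduced by this commit.
def normalize_scores(scores: dict[str, float]) -> dict[str, float]:
    """Min-max normalize scores into [0, 1]; a constant list maps everything to 1.0."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {cid: 1.0 for cid in scores}
    return {cid: (s - lo) / (hi - lo) for cid, s in scores.items()}


def weighted_rerank(vector_scores, keyword_scores, alpha=0.5):
    """Weighted average of normalized scores: alpha=1.0 ~ pure keyword, alpha=0.0 ~ pure vector."""
    vec, kw = normalize_scores(vector_scores), normalize_scores(keyword_scores)
    return {cid: alpha * kw.get(cid, 0.0) + (1 - alpha) * vec.get(cid, 0.0) for cid in set(vec) | set(kw)}


def rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0):
    """Reciprocal Rank Fusion: each result list contributes 1 / (impact_factor + rank) per chunk."""
    combined: dict[str, float] = {}
    for scores in (vector_scores, keyword_scores):
        for rank, cid in enumerate(sorted(scores, key=scores.get, reverse=True), start=1):
            combined[cid] = combined.get(cid, 0.0) + 1.0 / (impact_factor + rank)
    return combined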
@@ -93,7 +93,12 @@ async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embed
     query_string = "Sentence 5"
 
     response = await sqlite_vec_index.query_hybrid(
-        embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
+        embedding=query_embedding,
+        query_string=query_string,
+        k=3,
+        score_threshold=0.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
     )
 
     assert len(response.chunks) == 3, f"Expected 3 results, got {len(response.chunks)}"
@@ -175,7 +180,12 @@ async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_c
 
     # Get hybrid results
     response = await sqlite_vec_index.query_hybrid(
-        embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
+        embedding=query_embedding,
+        query_string=query_string,
+        k=3,
+        score_threshold=0.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
     )
 
     # Should still get results from vector search
@@ -198,6 +208,8 @@ async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chun
         query_string=query_string,
         k=3,
         score_threshold=1000.0,  # Very high threshold
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
     )
 
     # Should return no results due to high threshold
@@ -216,7 +228,12 @@ async def test_query_chunks_hybrid_different_embedding(
     query_string = "Sentence 5"
 
     response = await sqlite_vec_index.query_hybrid(
-        embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
+        embedding=query_embedding,
+        query_string=query_string,
+        k=3,
+        score_threshold=0.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
     )
 
     # Should still get results if keyword matches exist
@@ -236,7 +253,12 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
     query_string = "Sentence 5"
 
     response = await sqlite_vec_index.query_hybrid(
-        embedding=query_embedding, query_string=query_string, k=5, score_threshold=0.0
+        embedding=query_embedding,
+        query_string=query_string,
+        k=5,
+        score_threshold=0.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
     )
 
     # Verify we get results from both search methods
@@ -247,7 +269,6 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
 
 @pytest.mark.asyncio
 async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
     """Test that we correctly rank documents that appear in both search methods."""
     await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
 
     # Create a query embedding that's similar to the first chunk
@@ -255,26 +276,43 @@ async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
     # Use a keyword that appears in the first document
     query_string = "Sentence 0 from document 0"
 
     # First get individual results to verify ranks
     vector_response = await sqlite_vec_index.query_vector(query_embedding, k=5, score_threshold=0.0)
     keyword_response = await sqlite_vec_index.query_keyword(query_string, k=5, score_threshold=0.0)
 
     # Verify document-0 appears in both results
     assert any(chunk.metadata["document_id"] == "document-0" for chunk in vector_response.chunks), (
         "document-0 not found in vector search results"
     )
     assert any(chunk.metadata["document_id"] == "document-0" for chunk in keyword_response.chunks), (
         "document-0 not found in keyword search results"
     )
 
-    # Now get hybrid results
+    # Test weighted re-ranking
     response = await sqlite_vec_index.query_hybrid(
-        embedding=query_embedding, query_string=query_string, k=1, score_threshold=0.0
+        embedding=query_embedding,
+        query_string=query_string,
+        k=1,
+        score_threshold=0.0,
+        reranker_type="weighted",
+        reranker_params={"alpha": 0.5},
     )
 
     # Verify document-0 is ranked first in hybrid results
     assert len(response.chunks) == 1
     assert response.chunks[0].metadata["document_id"] == "document-0", "document-0 not ranked first in hybrid results"
+    # Score should be weighted average of normalized keyword score and vector score
+    assert response.scores[0] > 0.5  # Both scores should be high
+
+    # Test RRF re-ranking
+    response = await sqlite_vec_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=1,
+        score_threshold=0.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
+    )
+    assert len(response.chunks) == 1
+    # RRF score should be sum of reciprocal ranks
+    assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6)  # 1/(60+1) + 1/(60+1)
+
+    # Test default re-ranking (should be RRF)
+    response = await sqlite_vec_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=1,
+        score_threshold=0.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
+    )
+    assert len(response.chunks) == 1
+    assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6)  # Should behave like RRF
 
 
 @pytest.mark.asyncio
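The exact values asserted in this hunk follow from the reciprocal-rank arithmetic alone: document-0 is ranked first in both the vector and the keyword list, so each list contributes 1/(impact_factor + 1). A standalone check of the expected numbers (plain Python, no project code; the helper name is made up for illustration):

def rrf_top_score(impact_factor: float, n_lists: int = 2) -> float:
    """RRF score of a document ranked first in n_lists result lists."""
    return sum(1.0 / (impact_factor + 1) for _ in range(n_lists))

assert abs(rrf_top_score(60.0) - 2.0 / 61.0) < 1e-12    # asserted above
assert abs(rrf_top_score(10.0) - 2.0 / 11.0) < 1e-12    # asserted in the impact-factor test below
assert abs(rrf_top_score(100.0) - 2.0 / 101.0) < 1e-12  # asserted in the impact-factor test below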
@@ -288,7 +326,12 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
     query_string = "Sentence 9 from document 2"
 
     response = await sqlite_vec_index.query_hybrid(
-        embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
+        embedding=query_embedding,
+        query_string=query_string,
+        k=3,
+        score_threshold=0.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
     )
 
     # Should get results from both search methods
@@ -299,3 +342,176 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
     doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks}
     assert "document-0" in doc_ids  # From vector search
     assert "document-2" in doc_ids  # From keyword search
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_weighted_reranker_parametrization(
+    sqlite_vec_index, sample_chunks, sample_embeddings
+):
+    """Test WeightedReRanker with different alpha values."""
+    # Re-add data before each search to ensure test isolation
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+    query_embedding = sample_embeddings[0]
+    query_string = "Sentence 0 from document 0"
+
+    # alpha=1.0 (should behave like pure keyword)
+    response = await sqlite_vec_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=1,
+        score_threshold=0.0,
+        reranker_type="weighted",
+        reranker_params={"alpha": 1.0},
+    )
+    assert len(response.chunks) > 0  # Should get at least one result
+    assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
+
+    # alpha=0.0 (should behave like pure vector)
+    response = await sqlite_vec_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=1,
+        score_threshold=0.0,
+        reranker_type="weighted",
+        reranker_params={"alpha": 0.0},
+    )
+    assert len(response.chunks) > 0  # Should get at least one result
+    assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
+
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+    # alpha=0.7 (should be a mix)
+    response = await sqlite_vec_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=1,
+        score_threshold=0.0,
+        reranker_type="weighted",
+        reranker_params={"alpha": 0.7},
+    )
+    assert len(response.chunks) > 0  # Should get at least one result
+    assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
+    """Test RRFReRanker with different impact factors."""
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+    query_embedding = sample_embeddings[0]
+    query_string = "Sentence 0 from document 0"
+
+    # impact_factor=10
+    response = await sqlite_vec_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=1,
+        score_threshold=0.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 10.0},
+    )
+    assert len(response.chunks) == 1
+    assert response.scores[0] == pytest.approx(2.0 / 11.0, rel=1e-6)
+
+    # impact_factor=100
+    response = await sqlite_vec_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=1,
+        score_threshold=0.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 100.0},
+    )
+    assert len(response.chunks) == 1
+    assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
+    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # No results from either search - use a completely different embedding and a nonzero threshold
+    query_embedding = np.ones_like(sample_embeddings[0]) * -1  # Very different from sample embeddings
+    query_string = "no_such_keyword_that_will_never_match"
+    response = await sqlite_vec_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=3,
+        score_threshold=0.1,  # Nonzero threshold to filter out low-similarity matches
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
+    )
+    assert len(response.chunks) == 0
+
+    # All results below threshold
+    query_embedding = sample_embeddings[0]
+    query_string = "Sentence 0 from document 0"
+    response = await sqlite_vec_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=3,
+        score_threshold=1000.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
+    )
+    assert len(response.chunks) == 0
+
+    # Large k value
+    response = await sqlite_vec_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=100,
+        score_threshold=0.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
+    )
+    # Should not error, should return all available results
+    assert len(response.chunks) > 0
+    assert len(response.chunks) <= 100
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_tie_breaking(
+    sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
+):
+    """Test tie-breaking and determinism when scores are equal."""
+    # Create two chunks with the same content and embedding
+    chunk1 = Chunk(content="identical", metadata={"document_id": "docA"})
+    chunk2 = Chunk(content="identical", metadata={"document_id": "docB"})
+    chunks = [chunk1, chunk2]
+    # Use the same embedding for both chunks to ensure equal scores
+    same_embedding = sample_embeddings[0]
+    embeddings = np.array([same_embedding, same_embedding])
+
+    # Clear existing data and recreate index
+    await sqlite_vec_index.delete()
+    temp_dir = tmp_path_factory.getbasetemp()
+    db_path = str(temp_dir / "test_sqlite.db")
+    sqlite_vec_index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
+    await sqlite_vec_index.add_chunks(chunks, embeddings)
+
+    # Query with the same embedding and content to ensure equal scores
+    query_embedding = same_embedding
+    query_string = "identical"
+
+    # Run multiple queries to verify determinism
+    responses = []
+    for _ in range(3):
+        response = await sqlite_vec_index.query_hybrid(
+            embedding=query_embedding,
+            query_string=query_string,
+            k=2,
+            score_threshold=0.0,
+            reranker_type="rrf",
+            reranker_params={"impact_factor": 60.0},
+        )
+        responses.append(response)
+
+    # Verify all responses are identical
+    first_response = responses[0]
+    for response in responses[1:]:
+        assert response.chunks == first_response.chunks
+        assert response.scores == first_response.scores
+
+    # Verify both chunks are returned with equal scores
+    assert len(first_response.chunks) == 2
+    assert first_response.scores[0] == first_response.scores[1]
+    assert {chunk.metadata["document_id"] for chunk in first_response.chunks} == {"docA", "docB"}
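The tie-breaking test at the end only requires that equal scores come back in a stable order across repeated queries; the diff does not show how SQLiteVecIndex achieves this. One common approach, shown here purely as a sketch and not as the project's code, is to break ties on a secondary key such as the document id:

def rank_deterministically(scores: dict[str, float]) -> list[tuple[str, float]]:
    # Sort by score descending, then by id ascending, so equal scores always order the same way.
    return sorted(scores.items(), key=lambda item: (-item[1], item[0]))

# Equal scores are returned in the same order on every call, as the test above demands.
assert rank_deterministically({"docB": 0.5, "docA": 0.5}) == [("docA", 0.5), ("docB", 0.5)]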