[Feat] Implement keyword search in FAISS

This commit is contained in:
Varsha Prasad Narsing 2025-08-05 15:49:14 -07:00
parent 7f834339ba
commit aa7579efaf
3 changed files with 142 additions and 1 deletions

View file

@ -9,6 +9,8 @@ import base64
import io import io
import json import json
import logging import logging
import re
from collections import Counter
from typing import Any from typing import Any
import faiss import faiss
@ -29,6 +31,7 @@ from llama_stack.providers.datatypes import (
HealthStatus, HealthStatus,
VectorDBsProtocolPrivate, VectorDBsProtocolPrivate,
) )
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@ -49,6 +52,35 @@ OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::" OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
def _tokenize_text(text: str) -> list[str]:
"""Tokenize text into words, converting to lowercase and removing punctuation."""
words = re.findall(r"\b\w+\b", text.lower())
return [word for word in words if len(word) > 2]
def _calculate_tf_idf_score(query_tokens: list[str], document_tokens: list[str]) -> float:
"""
Calculate a simple TF-IDF-like score for keyword matching.
This is a simplified version that doesn't require pre-computed IDF values.
"""
if not query_tokens or not document_tokens:
return 0.0
query_freq = Counter(query_tokens)
doc_freq = Counter(document_tokens)
score = 0.0
for term, query_count in query_freq.items():
if term in doc_freq:
score += doc_freq[term] * query_count
if score > 0 and len(document_tokens) > 0:
score /= len(document_tokens)
return score
return 0.0
class FaissIndex(EmbeddingIndex): class FaissIndex(EmbeddingIndex):
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None): def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
self.index = faiss.IndexFlatL2(dimension) self.index = faiss.IndexFlatL2(dimension)
@ -174,7 +206,27 @@ class FaissIndex(EmbeddingIndex):
k: int, k: int,
score_threshold: float, score_threshold: float,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in FAISS") query_tokens = _tokenize_text(query_string)
if not query_tokens:
return QueryChunksResponse(chunks=[], scores=[])
# Calculate scores for all chunks
chunk_scores = []
for chunk in self.chunk_by_index.values():
document_content = interleaved_content_as_str(chunk.content)
document_tokens = _tokenize_text(document_content)
score = _calculate_tf_idf_score(query_tokens, document_tokens)
if score > 0 and score >= score_threshold:
chunk_scores.append((chunk, score))
# Sort by score (descending) and take top k
chunk_scores.sort(key=lambda x: x[1], reverse=True)
top_k = chunk_scores[:k]
chunks = [chunk for chunk, _ in top_k]
scores = [score for _, score in top_k]
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid( async def query_hybrid(
self, self,

View file

@ -52,6 +52,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
], ],
"keyword": [ "keyword": [
"inline::sqlite-vec", "inline::sqlite-vec",
"inline::faiss",
], ],
"hybrid": [ "hybrid": [
"inline::sqlite-vec", "inline::sqlite-vec",

View file

@ -176,3 +176,91 @@ async def test_health_failure():
assert isinstance(response, dict) assert isinstance(response, dict)
assert response["status"] == HealthStatus.ERROR assert response["status"] == HealthStatus.ERROR
assert response["message"] == "Health check failed: Test error" assert response["message"] == "Health check failed: Test error"
# Keyword Search Tests
@pytest.fixture
def keyword_search_chunks():
return [
Chunk(
content="Python is a high-level programming language that emphasizes code readability.",
metadata={"document_id": "doc1", "topic": "programming"},
),
Chunk(
content="Machine learning is a subset of artificial intelligence that enables systems to learn automatically.",
metadata={"document_id": "doc2", "topic": "ai"},
),
Chunk(
content="Data structures are fundamental to computer science and enable efficient data processing.",
metadata={"document_id": "doc3", "topic": "computer_science"},
),
Chunk(
content="Neural networks are inspired by biological neural networks and use interconnected nodes.",
metadata={"document_id": "doc4", "topic": "ai"},
),
]
@pytest.fixture
def keyword_search_embeddings(embedding_dimension):
return np.random.rand(4, embedding_dimension).astype(np.float32)
async def test_faiss_keyword_search_basic(faiss_index, keyword_search_chunks, keyword_search_embeddings):
"""Test basic keyword search functionality."""
await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)
response = await faiss_index.query_keyword("Python", k=2, score_threshold=0.0)
assert len(response.chunks) > 0
assert "Python" in response.chunks[0].content
response = await faiss_index.query_keyword("machine learning", k=2, score_threshold=0.0)
assert len(response.chunks) > 0
assert "machine learning" in response.chunks[0].content.lower()
async def test_faiss_keyword_search_no_matches(faiss_index, keyword_search_chunks, keyword_search_embeddings):
"""Test keyword search when no matches are found."""
await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)
# Test with a term that doesn't exist
response = await faiss_index.query_keyword("nonexistent", k=2, score_threshold=0.0)
assert len(response.chunks) == 0
assert len(response.scores) == 0
async def test_faiss_keyword_search_score_threshold(faiss_index, keyword_search_chunks, keyword_search_embeddings):
"""Test that score threshold filtering works correctly."""
await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)
response = await faiss_index.query_keyword("Python", k=2, score_threshold=100.0)
assert len(response.chunks) == 0
async def test_faiss_keyword_search_empty_index(faiss_index):
"""Test keyword search on empty index."""
response = await faiss_index.query_keyword("Python", k=2, score_threshold=0.0)
assert len(response.chunks) == 0
assert len(response.scores) == 0
async def test_faiss_keyword_search_empty_query(faiss_index, keyword_search_chunks, keyword_search_embeddings):
"""Test keyword search with empty query."""
await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)
response = await faiss_index.query_keyword("", k=2, score_threshold=0.0)
assert len(response.chunks) == 0
assert len(response.scores) == 0
async def test_faiss_keyword_search_case_insensitive(faiss_index, keyword_search_chunks, keyword_search_embeddings):
"""Test that keyword search is case insensitive."""
await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)
# Test with different cases
response1 = await faiss_index.query_keyword("python", k=2, score_threshold=0.0)
response2 = await faiss_index.query_keyword("PYTHON", k=2, score_threshold=0.0)
response3 = await faiss_index.query_keyword("Python", k=2, score_threshold=0.0)
# All should return the same results
assert len(response1.chunks) == len(response2.chunks) == len(response3.chunks)