From aa7579efaf1ed30fd580ed9d7048878407d57130 Mon Sep 17 00:00:00 2001
From: Varsha Prasad Narsing
Date: Tue, 5 Aug 2025 15:49:14 -0700
Subject: [PATCH] [Feat] Implement keyword search in FAISS

---
 .../providers/inline/vector_io/faiss/faiss.py | 54 +++++++++++-
 .../vector_io/test_openai_vector_stores.py    |  1 +
 tests/unit/providers/vector_io/test_faiss.py  | 88 +++++++++++++++++++
 3 files changed, 142 insertions(+), 1 deletion(-)

diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index 7a5373726..03fd9955c 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -9,6 +9,8 @@ import base64
 import io
 import json
 import logging
+import re
+from collections import Counter
 from typing import Any
 
 import faiss
@@ -29,6 +31,7 @@ from llama_stack.providers.datatypes import (
     HealthStatus,
     VectorDBsProtocolPrivate,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@@ -49,6 +52,35 @@ OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
 OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
 
 
+def _tokenize_text(text: str) -> list[str]:
+    """Tokenize text into words, converting to lowercase and removing punctuation."""
+    words = re.findall(r"\b\w+\b", text.lower())
+    return [word for word in words if len(word) > 2]
+
+
+def _calculate_tf_idf_score(query_tokens: list[str], document_tokens: list[str]) -> float:
+    """
+    Calculate a simple TF-IDF-like score for keyword matching.
+    This is a simplified version that doesn't require pre-computed IDF values.
+    """
+    if not query_tokens or not document_tokens:
+        return 0.0
+
+    query_freq = Counter(query_tokens)
+    doc_freq = Counter(document_tokens)
+
+    score = 0.0
+    for term, query_count in query_freq.items():
+        if term in doc_freq:
+            score += doc_freq[term] * query_count
+
+    if score > 0 and len(document_tokens) > 0:
+        score /= len(document_tokens)
+        return score
+
+    return 0.0
+
+
 class FaissIndex(EmbeddingIndex):
     def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         self.index = faiss.IndexFlatL2(dimension)
@@ -174,7 +206,27 @@ class FaissIndex(EmbeddingIndex):
         k: int,
         score_threshold: float,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Keyword search is not supported in FAISS")
+        query_tokens = _tokenize_text(query_string)
+        if not query_tokens:
+            return QueryChunksResponse(chunks=[], scores=[])
+
+        # Calculate scores for all chunks
+        chunk_scores = []
+        for chunk in self.chunk_by_index.values():
+            document_content = interleaved_content_as_str(chunk.content)
+            document_tokens = _tokenize_text(document_content)
+            score = _calculate_tf_idf_score(query_tokens, document_tokens)
+            if score > 0 and score >= score_threshold:
+                chunk_scores.append((chunk, score))
+
+        # Sort by score (descending) and take top k
+        chunk_scores.sort(key=lambda x: x[1], reverse=True)
+        top_k = chunk_scores[:k]
+
+        chunks = [chunk for chunk, _ in top_k]
+        scores = [score for _, score in top_k]
+
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 
     async def query_hybrid(
         self,
diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py
index 1c9ef92b6..4037618c7 100644
--- a/tests/integration/vector_io/test_openai_vector_stores.py
+++ b/tests/integration/vector_io/test_openai_vector_stores.py
@@ -52,6 +52,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
         ],
         "keyword": [
             "inline::sqlite-vec",
+            "inline::faiss",
         ],
         "hybrid": [
             "inline::sqlite-vec",
diff --git a/tests/unit/providers/vector_io/test_faiss.py b/tests/unit/providers/vector_io/test_faiss.py
index 90108d7a0..7764b9253 100644
--- a/tests/unit/providers/vector_io/test_faiss.py
+++ b/tests/unit/providers/vector_io/test_faiss.py
@@ -176,3 +176,91 @@ async def test_health_failure():
     assert isinstance(response, dict)
     assert response["status"] == HealthStatus.ERROR
     assert response["message"] == "Health check failed: Test error"
+
+
+# Keyword Search Tests
+@pytest.fixture
+def keyword_search_chunks():
+    return [
+        Chunk(
+            content="Python is a high-level programming language that emphasizes code readability.",
+            metadata={"document_id": "doc1", "topic": "programming"},
+        ),
+        Chunk(
+            content="Machine learning is a subset of artificial intelligence that enables systems to learn automatically.",
+            metadata={"document_id": "doc2", "topic": "ai"},
+        ),
+        Chunk(
+            content="Data structures are fundamental to computer science and enable efficient data processing.",
+            metadata={"document_id": "doc3", "topic": "computer_science"},
+        ),
+        Chunk(
+            content="Neural networks are inspired by biological neural networks and use interconnected nodes.",
+            metadata={"document_id": "doc4", "topic": "ai"},
+        ),
+    ]
+
+
+@pytest.fixture
+def keyword_search_embeddings(embedding_dimension):
+    return np.random.rand(4, embedding_dimension).astype(np.float32)
+
+
+async def test_faiss_keyword_search_basic(faiss_index, keyword_search_chunks, keyword_search_embeddings):
+    """Test basic keyword search functionality."""
+    await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)
+
+    response = await faiss_index.query_keyword("Python", k=2, score_threshold=0.0)
+    assert len(response.chunks) > 0
+    assert "Python" in response.chunks[0].content
+
+    response = await faiss_index.query_keyword("machine learning", k=2, score_threshold=0.0)
+    assert len(response.chunks) > 0
+    assert "machine learning" in response.chunks[0].content.lower()
+
+
+async def test_faiss_keyword_search_no_matches(faiss_index, keyword_search_chunks, keyword_search_embeddings):
+    """Test keyword search when no matches are found."""
+    await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)
+
+    # Test with a term that doesn't exist
+    response = await faiss_index.query_keyword("nonexistent", k=2, score_threshold=0.0)
+    assert len(response.chunks) == 0
+    assert len(response.scores) == 0
+
+
+async def test_faiss_keyword_search_score_threshold(faiss_index, keyword_search_chunks, keyword_search_embeddings):
+    """Test that score threshold filtering works correctly."""
+    await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)
+
+    response = await faiss_index.query_keyword("Python", k=2, score_threshold=100.0)
+    assert len(response.chunks) == 0
+
+
+async def test_faiss_keyword_search_empty_index(faiss_index):
+    """Test keyword search on an empty index."""
+    response = await faiss_index.query_keyword("Python", k=2, score_threshold=0.0)
+    assert len(response.chunks) == 0
+    assert len(response.scores) == 0
+
+
+async def test_faiss_keyword_search_empty_query(faiss_index, keyword_search_chunks, keyword_search_embeddings):
+    """Test keyword search with an empty query."""
+    await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)
+
+    response = await faiss_index.query_keyword("", k=2, score_threshold=0.0)
+    assert len(response.chunks) == 0
+    assert len(response.scores) == 0
+
+
+async def test_faiss_keyword_search_case_insensitive(faiss_index, keyword_search_chunks, keyword_search_embeddings):
+    """Test that keyword search is case-insensitive."""
+    await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)
+
+    # Test with different cases
+    response1 = await faiss_index.query_keyword("python", k=2, score_threshold=0.0)
+    response2 = await faiss_index.query_keyword("PYTHON", k=2, score_threshold=0.0)
+    response3 = await faiss_index.query_keyword("Python", k=2, score_threshold=0.0)
+
+    # All should return the same results
+    assert len(response1.chunks) == len(response2.chunks) == len(response3.chunks)
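
Note for reviewers (not part of the diff): the sketch below illustrates the ranking behavior the patch introduces. It mirrors the logic of _tokenize_text and _calculate_tf_idf_score as standalone functions so it runs without llama_stack installed; the helper names tokenize/score and the sample texts are made up for this example only.

import re
from collections import Counter


def tokenize(text: str) -> list[str]:
    # Lowercase, split on word boundaries, drop tokens shorter than 3 chars
    # (same behavior as _tokenize_text in the patch).
    return [w for w in re.findall(r"\b\w+\b", text.lower()) if len(w) > 2]


def score(query: str, document: str) -> float:
    # Same scheme as _calculate_tf_idf_score: sum of (doc term count * query
    # term count) over shared terms, normalized by document length;
    # returns 0.0 when nothing matches.
    q_tokens, d_tokens = tokenize(query), tokenize(document)
    if not q_tokens or not d_tokens:
        return 0.0
    q_freq, d_freq = Counter(q_tokens), Counter(d_tokens)
    raw = sum(d_freq[t] * c for t, c in q_freq.items() if t in d_freq)
    return raw / len(d_tokens) if raw > 0 else 0.0


# Illustrative chunk texts (not the ones used in the unit tests above).
docs = {
    "doc1": "Python is a high-level programming language.",
    "doc2": "Machine learning enables systems to learn from data.",
}
for doc_id, text in docs.items():
    print(doc_id, round(score("python programming", text), 3))
# doc1 -> 0.4 (both query terms match among 5 kept tokens); doc2 -> 0.0

Because the score is normalized by chunk length it stays small for typical chunks, which is why the unit tests use 0.0 as a permissive threshold and 100.0 to force an empty result.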