mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 14:08:00 +00:00
[Feat] Implement keyword search in FAISS
This commit is contained in:
parent
7f834339ba
commit
aa7579efaf
3 changed files with 142 additions and 1 deletions
|
@ -9,6 +9,8 @@ import base64
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
from collections import Counter
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import faiss
|
import faiss
|
||||||
|
@ -29,6 +31,7 @@ from llama_stack.providers.datatypes import (
|
||||||
HealthStatus,
|
HealthStatus,
|
||||||
VectorDBsProtocolPrivate,
|
VectorDBsProtocolPrivate,
|
||||||
)
|
)
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||||
|
@ -49,6 +52,35 @@ OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
|
||||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
|
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
|
||||||
|
|
||||||
|
|
||||||
|
def _tokenize_text(text: str) -> list[str]:
|
||||||
|
"""Tokenize text into words, converting to lowercase and removing punctuation."""
|
||||||
|
words = re.findall(r"\b\w+\b", text.lower())
|
||||||
|
return [word for word in words if len(word) > 2]
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_tf_idf_score(query_tokens: list[str], document_tokens: list[str]) -> float:
|
||||||
|
"""
|
||||||
|
Calculate a simple TF-IDF-like score for keyword matching.
|
||||||
|
This is a simplified version that doesn't require pre-computed IDF values.
|
||||||
|
"""
|
||||||
|
if not query_tokens or not document_tokens:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
query_freq = Counter(query_tokens)
|
||||||
|
doc_freq = Counter(document_tokens)
|
||||||
|
|
||||||
|
score = 0.0
|
||||||
|
for term, query_count in query_freq.items():
|
||||||
|
if term in doc_freq:
|
||||||
|
score += doc_freq[term] * query_count
|
||||||
|
|
||||||
|
if score > 0 and len(document_tokens) > 0:
|
||||||
|
score /= len(document_tokens)
|
||||||
|
return score
|
||||||
|
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
class FaissIndex(EmbeddingIndex):
|
class FaissIndex(EmbeddingIndex):
|
||||||
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
||||||
self.index = faiss.IndexFlatL2(dimension)
|
self.index = faiss.IndexFlatL2(dimension)
|
||||||
|
@ -174,7 +206,27 @@ class FaissIndex(EmbeddingIndex):
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in FAISS")
|
query_tokens = _tokenize_text(query_string)
|
||||||
|
if not query_tokens:
|
||||||
|
return QueryChunksResponse(chunks=[], scores=[])
|
||||||
|
|
||||||
|
# Calculate scores for all chunks
|
||||||
|
chunk_scores = []
|
||||||
|
for chunk in self.chunk_by_index.values():
|
||||||
|
document_content = interleaved_content_as_str(chunk.content)
|
||||||
|
document_tokens = _tokenize_text(document_content)
|
||||||
|
score = _calculate_tf_idf_score(query_tokens, document_tokens)
|
||||||
|
if score > 0 and score >= score_threshold:
|
||||||
|
chunk_scores.append((chunk, score))
|
||||||
|
|
||||||
|
# Sort by score (descending) and take top k
|
||||||
|
chunk_scores.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
top_k = chunk_scores[:k]
|
||||||
|
|
||||||
|
chunks = [chunk for chunk, _ in top_k]
|
||||||
|
scores = [score for _, score in top_k]
|
||||||
|
|
||||||
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
async def query_hybrid(
|
async def query_hybrid(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -52,6 +52,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
|
||||||
],
|
],
|
||||||
"keyword": [
|
"keyword": [
|
||||||
"inline::sqlite-vec",
|
"inline::sqlite-vec",
|
||||||
|
"inline::faiss",
|
||||||
],
|
],
|
||||||
"hybrid": [
|
"hybrid": [
|
||||||
"inline::sqlite-vec",
|
"inline::sqlite-vec",
|
||||||
|
|
|
@ -176,3 +176,91 @@ async def test_health_failure():
|
||||||
assert isinstance(response, dict)
|
assert isinstance(response, dict)
|
||||||
assert response["status"] == HealthStatus.ERROR
|
assert response["status"] == HealthStatus.ERROR
|
||||||
assert response["message"] == "Health check failed: Test error"
|
assert response["message"] == "Health check failed: Test error"
|
||||||
|
|
||||||
|
|
||||||
|
# Keyword Search Tests
|
||||||
|
@pytest.fixture
|
||||||
|
def keyword_search_chunks():
    """Build the four sample chunks used by the keyword-search tests.

    Each chunk holds one sentence plus ``document_id`` and ``topic`` metadata.
    """
    # (content, document_id, topic) for every sample, in insertion order.
    samples = (
        (
            "Python is a high-level programming language that emphasizes code readability.",
            "doc1",
            "programming",
        ),
        (
            "Machine learning is a subset of artificial intelligence that enables systems to learn automatically.",
            "doc2",
            "ai",
        ),
        (
            "Data structures are fundamental to computer science and enable efficient data processing.",
            "doc3",
            "computer_science",
        ),
        (
            "Neural networks are inspired by biological neural networks and use interconnected nodes.",
            "doc4",
            "ai",
        ),
    )
    return [
        Chunk(content=text, metadata={"document_id": doc_id, "topic": topic})
        for text, doc_id, topic in samples
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def keyword_search_embeddings(embedding_dimension):
    """Return a random float32 matrix with four rows of ``embedding_dimension`` values."""
    # Four rows: the keyword-search tests pair this matrix with four chunks.
    shape = (4, embedding_dimension)
    vectors = np.random.rand(*shape)
    return vectors.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_faiss_keyword_search_basic(faiss_index, keyword_search_chunks, keyword_search_embeddings):
    """A keyword query should rank a chunk containing the queried words first."""
    await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)

    # Single-word query: the best hit must mention the word verbatim.
    python_hits = await faiss_index.query_keyword("Python", k=2, score_threshold=0.0)
    assert len(python_hits.chunks) > 0
    best_python_chunk = python_hits.chunks[0]
    assert "Python" in best_python_chunk.content

    # Two-word query: the best hit must contain the phrase, ignoring case.
    ml_hits = await faiss_index.query_keyword("machine learning", k=2, score_threshold=0.0)
    assert len(ml_hits.chunks) > 0
    best_ml_text = ml_hits.chunks[0].content.lower()
    assert "machine learning" in best_ml_text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_faiss_keyword_search_no_matches(faiss_index, keyword_search_chunks, keyword_search_embeddings):
    """A query term absent from every chunk should yield no chunks and no scores."""
    await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)

    # "nonexistent" does not occur in any of the sample chunks.
    result = await faiss_index.query_keyword("nonexistent", k=2, score_threshold=0.0)

    chunk_count = len(result.chunks)
    assert chunk_count == 0
    score_count = len(result.scores)
    assert score_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_faiss_keyword_search_score_threshold(faiss_index, keyword_search_chunks, keyword_search_embeddings):
    """A very high score threshold should filter out every chunk, even for a matching term."""
    await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)

    # "Python" matches a chunk, but the threshold of 100.0 is expected to exclude it.
    filtered = await faiss_index.query_keyword("Python", k=2, score_threshold=100.0)

    returned = len(filtered.chunks)
    assert returned == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_faiss_keyword_search_empty_index(faiss_index):
    """Querying an index that holds no chunks should return empty chunks and scores."""
    # No add_chunks call: the index is queried in its freshly created state.
    result = await faiss_index.query_keyword("Python", k=2, score_threshold=0.0)

    chunk_count = len(result.chunks)
    assert chunk_count == 0
    score_count = len(result.scores)
    assert score_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_faiss_keyword_search_empty_query(faiss_index, keyword_search_chunks, keyword_search_embeddings):
    """An empty query string should return no chunks and no scores from a populated index."""
    await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)

    blank_query = ""
    result = await faiss_index.query_keyword(blank_query, k=2, score_threshold=0.0)

    chunk_count = len(result.chunks)
    assert chunk_count == 0
    score_count = len(result.scores)
    assert score_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_faiss_keyword_search_case_insensitive(faiss_index, keyword_search_chunks, keyword_search_embeddings):
    """Keyword search must return identical results regardless of query casing.

    Comparing only result counts would also pass when every query returns
    nothing, or when the casings return different chunks, so the hits must be
    non-empty and their contents and scores must match.
    """
    await faiss_index.add_chunks(keyword_search_chunks, keyword_search_embeddings)

    # Run the same term in lower, upper and title case.
    responses = [
        await faiss_index.query_keyword(query, k=2, score_threshold=0.0)
        for query in ("python", "PYTHON", "Python")
    ]

    reference = responses[0]
    # Guard against the vacuous case in which every query returns no hits.
    assert len(reference.chunks) > 0

    reference_contents = [chunk.content for chunk in reference.chunks]
    for response in responses[1:]:
        assert len(response.chunks) == len(reference.chunks)
        assert [chunk.content for chunk in response.chunks] == reference_contents
        assert response.scores == reference.scores
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue