mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-24 16:50:00 +00:00
[Feat] Implement keyword search in FAISS
This commit is contained in:
parent
7f834339ba
commit
aa7579efaf
3 changed files with 142 additions and 1 deletions
|
|
@ -9,6 +9,8 @@ import base64
|
|||
import io
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import Any
|
||||
|
||||
import faiss
|
||||
|
|
@ -29,6 +31,7 @@ from llama_stack.providers.datatypes import (
|
|||
HealthStatus,
|
||||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
|
@ -49,6 +52,35 @@ OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
|
|||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
|
||||
|
||||
|
||||
def _tokenize_text(text: str) -> list[str]:
|
||||
"""Tokenize text into words, converting to lowercase and removing punctuation."""
|
||||
words = re.findall(r"\b\w+\b", text.lower())
|
||||
return [word for word in words if len(word) > 2]
|
||||
|
||||
|
||||
def _calculate_tf_idf_score(query_tokens: list[str], document_tokens: list[str]) -> float:
|
||||
"""
|
||||
Calculate a simple TF-IDF-like score for keyword matching.
|
||||
This is a simplified version that doesn't require pre-computed IDF values.
|
||||
"""
|
||||
if not query_tokens or not document_tokens:
|
||||
return 0.0
|
||||
|
||||
query_freq = Counter(query_tokens)
|
||||
doc_freq = Counter(document_tokens)
|
||||
|
||||
score = 0.0
|
||||
for term, query_count in query_freq.items():
|
||||
if term in doc_freq:
|
||||
score += doc_freq[term] * query_count
|
||||
|
||||
if score > 0 and len(document_tokens) > 0:
|
||||
score /= len(document_tokens)
|
||||
return score
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
class FaissIndex(EmbeddingIndex):
|
||||
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
||||
self.index = faiss.IndexFlatL2(dimension)
|
||||
|
|
@ -174,7 +206,27 @@ class FaissIndex(EmbeddingIndex):
|
|||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in FAISS")
|
||||
query_tokens = _tokenize_text(query_string)
|
||||
if not query_tokens:
|
||||
return QueryChunksResponse(chunks=[], scores=[])
|
||||
|
||||
# Calculate scores for all chunks
|
||||
chunk_scores = []
|
||||
for chunk in self.chunk_by_index.values():
|
||||
document_content = interleaved_content_as_str(chunk.content)
|
||||
document_tokens = _tokenize_text(document_content)
|
||||
score = _calculate_tf_idf_score(query_tokens, document_tokens)
|
||||
if score > 0 and score >= score_threshold:
|
||||
chunk_scores.append((chunk, score))
|
||||
|
||||
# Sort by score (descending) and take top k
|
||||
chunk_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
top_k = chunk_scores[:k]
|
||||
|
||||
chunks = [chunk for chunk, _ in top_k]
|
||||
scores = [score for _, score in top_k]
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue