Fix formatting and finalize Milvus BM25 integration

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-07-07 12:48:40 -07:00 committed by Varsha
parent 86cca275c1
commit ac039e6bac
3 changed files with 63 additions and 40 deletions

View file

@ -12,7 +12,7 @@ import re
from typing import Any
from numpy.typing import NDArray
from pymilvus import DataType, MilvusClient
from pymilvus import DataType, Function, FunctionType, MilvusClient
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
@ -113,7 +113,6 @@ class MilvusIndex(EmbeddingIndex):
)
# Add BM25 function for full-text search
from pymilvus import Function, FunctionType
bm25_function = Function(
name="text_bm25_emb",
input_field_names=["content"],
@ -159,7 +158,7 @@ class MilvusIndex(EmbeddingIndex):
anns_field="vector",
limit=k,
output_fields=["*"],
search_params={"metric_type": "COSINE", "params": {"score_threshold": score_threshold}},
search_params={"params": {"radius": score_threshold}},
)
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
scores = [res["distance"] for res in search_res[0]]
@ -175,7 +174,7 @@ class MilvusIndex(EmbeddingIndex):
Perform BM25-based keyword search using Milvus's built-in full-text search.
"""
try:
from pymilvus import Function, FunctionType
# Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,
@ -189,24 +188,46 @@ class MilvusIndex(EmbeddingIndex):
}
},
)
chunks = []
scores = []
for res in search_res[0]:
chunk = Chunk(**res["entity"]["chunk_content"])
chunks.append(chunk)
scores.append(res["distance"]) # BM25 score from Milvus
# Filter by score threshold
filtered_results = [(chunk, score) for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
if filtered_results:
chunks, scores = zip(*filtered_results, strict=False)
return QueryChunksResponse(chunks=list(chunks), scores=list(scores))
else:
return QueryChunksResponse(chunks=[], scores=[])
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
filtered_scores = [score for score in scores if score >= score_threshold]
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
except Exception as e:
logger.error(f"Error performing BM25 search: {e}")
# Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold)
async def _fallback_keyword_search(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""
Fallback to simple text search when BM25 search is not available.
"""
# Simple text search using content field
search_res = await asyncio.to_thread(
self.client.query,
collection_name=self.collection_name,
filter=f'content like "%{query_string}%"',
output_fields=["*"],
limit=k,
)
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
scores = [1.0] * len(chunks) # Simple binary score for text search
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid(
self,
embedding: NDArray,