mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 18:51:55 +00:00
Fix formatting and finalize Milvus BM25 integration
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
86cca275c1
commit
ac039e6bac
3 changed files with 63 additions and 40 deletions
|
|
@ -12,7 +12,7 @@ import re
|
|||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymilvus import DataType, MilvusClient
|
||||
from pymilvus import DataType, Function, FunctionType, MilvusClient
|
||||
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
|
|
@ -113,7 +113,6 @@ class MilvusIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
# Add BM25 function for full-text search
|
||||
from pymilvus import Function, FunctionType
|
||||
bm25_function = Function(
|
||||
name="text_bm25_emb",
|
||||
input_field_names=["content"],
|
||||
|
|
@ -159,7 +158,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
anns_field="vector",
|
||||
limit=k,
|
||||
output_fields=["*"],
|
||||
search_params={"metric_type": "COSINE", "params": {"score_threshold": score_threshold}},
|
||||
search_params={"params": {"radius": score_threshold}},
|
||||
)
|
||||
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
|
||||
scores = [res["distance"] for res in search_res[0]]
|
||||
|
|
@ -175,7 +174,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
Perform BM25-based keyword search using Milvus's built-in full-text search.
|
||||
"""
|
||||
try:
|
||||
from pymilvus import Function, FunctionType
|
||||
# Use Milvus's built-in BM25 search
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
collection_name=self.collection_name,
|
||||
|
|
@ -189,24 +188,46 @@ class MilvusIndex(EmbeddingIndex):
|
|||
}
|
||||
},
|
||||
)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for res in search_res[0]:
|
||||
chunk = Chunk(**res["entity"]["chunk_content"])
|
||||
chunks.append(chunk)
|
||||
scores.append(res["distance"]) # BM25 score from Milvus
|
||||
|
||||
# Filter by score threshold
|
||||
filtered_results = [(chunk, score) for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
|
||||
if filtered_results:
|
||||
chunks, scores = zip(*filtered_results, strict=False)
|
||||
return QueryChunksResponse(chunks=list(chunks), scores=list(scores))
|
||||
else:
|
||||
return QueryChunksResponse(chunks=[], scores=[])
|
||||
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
|
||||
filtered_scores = [score for score in scores if score >= score_threshold]
|
||||
|
||||
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing BM25 search: {e}")
|
||||
# Fallback to simple text search
|
||||
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||
|
||||
async def _fallback_keyword_search(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
"""
|
||||
Fallback to simple text search when BM25 search is not available.
|
||||
"""
|
||||
# Simple text search using content field
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.query,
|
||||
collection_name=self.collection_name,
|
||||
filter=f'content like "%{query_string}%"',
|
||||
output_fields=["*"],
|
||||
limit=k,
|
||||
)
|
||||
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
|
||||
scores = [1.0] * len(chunks) # Simple binary score for text search
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue