Feat: Implement keyword search in Milvus

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-05-22 15:21:40 -07:00
parent 40e2c97915
commit 2cb927b498
3 changed files with 247 additions and 6 deletions

View file

@ -12,7 +12,7 @@ import uuid
from typing import Any
from numpy.typing import NDArray
from pymilvus import MilvusClient
from pymilvus import DataType, MilvusClient
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
@ -43,6 +43,8 @@ class MilvusIndex(EmbeddingIndex):
self.client = client
self.collection_name = collection_name.replace("-", "_")
self.consistency_level = consistency_level
self.bm25 = None
self.vectorizer = None
async def delete(self):
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
@ -53,11 +55,42 @@ class MilvusIndex(EmbeddingIndex):
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
# Create schema for vector search
schema = self.client.create_schema()
schema.add_field(
field_name="chunk_id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=100,
)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=len(embeddings[0]),
)
schema.add_field(
field_name="chunk_content",
datatype=DataType.JSON,
)
# Create indexes
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="FLAT",
metric_type="COSINE",
)
await asyncio.to_thread(
self.client.create_collection,
self.collection_name,
dimension=len(embeddings[0]),
auto_id=True,
schema=schema,
index_params=index_params,
consistency_level=self.consistency_level,
)
@ -68,6 +101,7 @@ class MilvusIndex(EmbeddingIndex):
data.append(
{
"chunk_id": chunk_id,
"content": chunk.content,
"vector": embedding,
"chunk_content": chunk.model_dump(),
}
@ -87,9 +121,10 @@ class MilvusIndex(EmbeddingIndex):
self.client.search,
collection_name=self.collection_name,
data=[embedding],
anns_field="vector",
limit=k,
output_fields=["*"],
search_params={"params": {"radius": score_threshold}},
search_params={"metric_type": "COSINE", "params": {"score_threshold": score_threshold}},
)
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
scores = [res["distance"] for res in search_res[0]]
@ -101,7 +136,17 @@ class MilvusIndex(EmbeddingIndex):
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Milvus")
# Substring match on the content field via a `like` filter; results are not relevance-ranked
search_res = await asyncio.to_thread(
self.client.query,
collection_name=self.collection_name,
filter=f'content like "%{query_string}%"',
output_fields=["*"],
limit=k,
)
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
scores = [1.0] * len(chunks) # Simple binary score for text search
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid(
self,
@ -195,6 +240,14 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
if params and params.get("mode") == "keyword":
# Check if this is inline Milvus (Milvus-Lite)
if hasattr(self.config, "db_path"):
raise NotImplementedError(
"Keyword search is not supported in Milvus-Lite. "
"Please use a remote Milvus server for keyword search functionality."
)
return await index.query_chunks(query, params)
async def openai_create_vector_store(