mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-16 01:53:10 +00:00
Feat: Implement keyword search in milvus
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
40e2c97915
commit
2cb927b498
3 changed files with 247 additions and 6 deletions
|
@ -12,7 +12,7 @@ import uuid
|
|||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymilvus import MilvusClient
|
||||
from pymilvus import DataType, MilvusClient
|
||||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
|
@ -43,6 +43,8 @@ class MilvusIndex(EmbeddingIndex):
|
|||
self.client = client
|
||||
self.collection_name = collection_name.replace("-", "_")
|
||||
self.consistency_level = consistency_level
|
||||
self.bm25 = None
|
||||
self.vectorizer = None
|
||||
|
||||
async def delete(self):
|
||||
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
|
@ -53,11 +55,42 @@ class MilvusIndex(EmbeddingIndex):
|
|||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
)
|
||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
# Create schema for vector search
|
||||
schema = self.client.create_schema()
|
||||
schema.add_field(
|
||||
field_name="chunk_id",
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=100,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="content",
|
||||
datatype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=len(embeddings[0]),
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="chunk_content",
|
||||
datatype=DataType.JSON,
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="FLAT",
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
self.collection_name,
|
||||
dimension=len(embeddings[0]),
|
||||
auto_id=True,
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
consistency_level=self.consistency_level,
|
||||
)
|
||||
|
||||
|
@ -68,6 +101,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
data.append(
|
||||
{
|
||||
"chunk_id": chunk_id,
|
||||
"content": chunk.content,
|
||||
"vector": embedding,
|
||||
"chunk_content": chunk.model_dump(),
|
||||
}
|
||||
|
@ -87,9 +121,10 @@ class MilvusIndex(EmbeddingIndex):
|
|||
self.client.search,
|
||||
collection_name=self.collection_name,
|
||||
data=[embedding],
|
||||
anns_field="vector",
|
||||
limit=k,
|
||||
output_fields=["*"],
|
||||
search_params={"params": {"radius": score_threshold}},
|
||||
search_params={"metric_type": "COSINE", "params": {"score_threshold": score_threshold}},
|
||||
)
|
||||
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
|
||||
scores = [res["distance"] for res in search_res[0]]
|
||||
|
@ -101,7 +136,17 @@ class MilvusIndex(EmbeddingIndex):
|
|||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Milvus")
|
||||
# Simple text search using content field
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.query,
|
||||
collection_name=self.collection_name,
|
||||
filter=f'content like "%{query_string}%"',
|
||||
output_fields=["*"],
|
||||
limit=k,
|
||||
)
|
||||
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
|
||||
scores = [1.0] * len(chunks) # Simple binary score for text search
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
|
@ -195,6 +240,14 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
if params and params.get("mode") == "keyword":
|
||||
# Check if this is inline Milvus (Milvus-Lite)
|
||||
if hasattr(self.config, "db_path"):
|
||||
raise NotImplementedError(
|
||||
"Keyword search is not supported in Milvus-Lite. "
|
||||
"Please use a remote Milvus server for keyword search functionality."
|
||||
)
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def openai_create_vector_store(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue