mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-24 20:18:06 +00:00
Fix adding index for BM25 func
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
ac039e6bac
commit
143bc7eb74
4 changed files with 30 additions and 10 deletions
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
uri: str = Field(description="The URI of the Milvus server")
|
||||
token: str | None = Field(description="The token of the Milvus server")
|
||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
||||
|
||||
# This configuration allows additional fields to be passed through to the underlying Milvus client.
|
||||
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
|
||||
|
|
@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
|
||||
return {
|
||||
"uri": "${env.MILVUS_ENDPOINT}",
|
||||
"token": "${env.MILVUS_TOKEN}",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="milvus_remote_registry.db",
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,7 +74,9 @@ class MilvusIndex(EmbeddingIndex):
|
|||
assert len(chunks) == len(embeddings), (
|
||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
)
|
||||
|
||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
# Create schema for vector search
|
||||
schema = self.client.create_schema()
|
||||
schema.add_field(
|
||||
|
|
@ -98,7 +100,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
field_name="chunk_content",
|
||||
datatype=DataType.JSON,
|
||||
)
|
||||
# Add sparse vector field for BM25
|
||||
# Add sparse vector field for BM25 (required by the function)
|
||||
schema.add_field(
|
||||
field_name="sparse",
|
||||
datatype=DataType.SPARSE_FLOAT_VECTOR,
|
||||
|
|
@ -111,6 +113,12 @@ class MilvusIndex(EmbeddingIndex):
|
|||
index_type="FLAT",
|
||||
metric_type="COSINE",
|
||||
)
|
||||
# Add index for sparse field (required by BM25 function)
|
||||
index_params.add_index(
|
||||
field_name="sparse",
|
||||
index_type="SPARSE_INVERTED_INDEX",
|
||||
metric_type="BM25",
|
||||
)
|
||||
|
||||
# Add BM25 function for full-text search
|
||||
bm25_function = Function(
|
||||
|
|
@ -137,7 +145,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
"content": chunk.content,
|
||||
"vector": embedding,
|
||||
"chunk_content": chunk.model_dump(),
|
||||
# sparse field will be automatically populated by BM25 function
|
||||
# sparse field will be handled by BM25 function automatically
|
||||
}
|
||||
)
|
||||
try:
|
||||
|
|
@ -220,7 +228,8 @@ class MilvusIndex(EmbeddingIndex):
|
|||
search_res = await asyncio.to_thread(
|
||||
self.client.query,
|
||||
collection_name=self.collection_name,
|
||||
filter=f'content like "%{query_string}%"',
|
||||
filter='content like "%{content}%"',
|
||||
filter_params={"content": query_string},
|
||||
output_fields=["*"],
|
||||
limit=k,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue