fix: fixed Milvus integration code

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
This commit is contained in:
Cheney Zhang 2025-02-25 20:30:54 +08:00 committed by Ashwin Bharambe
parent 5d0d4c3467
commit fa67d79bfe
8 changed files with 77 additions and 162 deletions

View file

@ -68,6 +68,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
| FAISS | Single Node | | FAISS | Single Node |
| SQLite-Vec| Single Node | | SQLite-Vec| Single Node |
| Chroma | Hosted and Single Node | | Chroma | Hosted and Single Node |
| Milvus | Hosted and Single Node |
| Postgres (PGVector) | Hosted and Single Node | | Postgres (PGVector) | Hosted and Single Node |
| Weaviate | Hosted | | Weaviate | Hosted |

View file

@ -2,7 +2,7 @@
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include: The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.), - LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.), - Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.) - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
Providers come in two flavors: Providers come in two flavors:
@ -55,5 +55,6 @@ vector_io/sqlite-vec
vector_io/chromadb vector_io/chromadb
vector_io/pgvector vector_io/pgvector
vector_io/qdrant vector_io/qdrant
vector_io/milvus
vector_io/weaviate vector_io/weaviate
``` ```

View file

@ -0,0 +1,31 @@
---
orphan: true
---
# Milvus
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database.
That means you're not limited to storing vectors in memory or in a separate service.
## Features
- Easy to use
- Fully integrated with Llama Stack
## Usage
To use Milvus in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Milvus.
3. Start storing and querying vectors.
## Installation
You can install Milvus using pymilvus:
```bash
pip install pymilvus
```
## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.

View file

@ -12,9 +12,7 @@ from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]): async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .milvus import MilvusVectorIOAdapter from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = MilvusVectorIOAdapter(config, deps[Api.inference]) impl = MilvusVectorIOAdapter(config, deps[Api.inference])
await impl.initialize() await impl.initialize()

View file

@ -1,143 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from typing import Any, Dict, List, Optional
from numpy.typing import NDArray
from pymilvus import MilvusClient
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import MilvusVectorIOConfig
logger = logging.getLogger(__name__)
class MilvusIndex(EmbeddingIndex):
def __init__(self, client: MilvusClient, collection_name: str):
self.client = client
self.collection_name = collection_name.replace("-", "_")
async def delete(self):
if self.client.has_collection(self.collection_name):
self.client.drop_collection(collection_name=self.collection_name)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not self.client.has_collection(self.collection_name):
self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True)
data = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
data.append(
{
"chunk_id": chunk_id,
"vector": embedding,
"chunk_content": chunk.model_dump(),
}
)
self.client.insert(
self.collection_name,
data=data,
)
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = self.client.search(
collection_name=self.collection_name,
data=[embedding],
limit=k,
output_fields=["*"],
search_params={"params": {"radius": score_threshold}},
)
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
scores = [res["distance"] for res in search_res[0]]
return QueryChunksResponse(chunks=chunks, scores=scores)
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: MilvusVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
uri = self.config.model_dump(exclude_none=True)["db_path"]
uri = os.path.expanduser(uri)
self.client = MilvusClient(uri=uri)
self.cache = {}
self.inference_api = inference_api
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
self.client.close()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(self.client, vector_db.identifier),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found")
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
return index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def insert_chunks(
self,
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)

View file

@ -114,7 +114,7 @@ def available_providers() -> List[ProviderSpec]:
Api.vector_io, Api.vector_io,
AdapterSpec( AdapterSpec(
adapter_type="milvus", adapter_type="milvus",
pip_packages=EMBEDDING_DEPS + ["pymilvus"], pip_packages=["pymilvus"],
module="llama_stack.providers.remote.vector_io.milvus", module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
), ),
@ -123,7 +123,7 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec( InlineProviderSpec(
api=Api.vector_io, api=Api.vector_io,
provider_type="inline::milvus", provider_type="inline::milvus",
pip_packages=EMBEDDING_DEPS + ["pymilvus"], pip_packages=["pymilvus"],
module="llama_stack.providers.inline.vector_io.milvus", module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig", config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference], api_dependencies=[Api.inference],

View file

@ -15,6 +15,7 @@ from llama_stack.schema_utils import json_schema_type
class MilvusVectorIOConfig(BaseModel): class MilvusVectorIOConfig(BaseModel):
uri: str uri: str
token: Optional[str] = None token: Optional[str] = None
consistency_level: str = "Strong"
@classmethod @classmethod
def sample_config(cls) -> Dict[str, Any]: def sample_config(cls) -> Dict[str, Any]:

View file

@ -5,8 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging
from typing import Any, Dict, List, Optional import os
from typing import Any, Dict, List, Optional, Union
import hashlib
import uuid
from numpy.typing import NDArray from numpy.typing import NDArray
from pymilvus import MilvusClient from pymilvus import MilvusClient
@ -14,20 +17,22 @@ from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorDBWithIndex,
) )
from .config import MilvusVectorIOConfig from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MilvusIndex(EmbeddingIndex): class MilvusIndex(EmbeddingIndex):
def __init__(self, client: MilvusClient, collection_name: str): def __init__(self, client: MilvusClient, collection_name: str, consistency_level="Strong"):
self.client = client self.client = client
self.collection_name = collection_name.replace("-", "_") self.collection_name = collection_name.replace("-", "_")
self.consistency_level = consistency_level
async def delete(self): async def delete(self):
if self.client.has_collection(self.collection_name): if self.client.has_collection(self.collection_name):
@ -38,11 +43,11 @@ class MilvusIndex(EmbeddingIndex):
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
) )
if not self.client.has_collection(self.collection_name): if not self.client.has_collection(self.collection_name):
self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True) self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True, consistency_level=self.consistency_level)
data = [] data = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)): for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}" chunk_id = generate_chunk_id(chunk.metadata["document_id"], chunk.content)
data.append( data.append(
{ {
@ -51,10 +56,14 @@ class MilvusIndex(EmbeddingIndex):
"chunk_content": chunk.model_dump(), "chunk_content": chunk.model_dump(),
} }
) )
self.client.insert( try:
self.collection_name, self.client.insert(
data=data, self.collection_name,
) data=data,
)
except Exception as e:
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = self.client.search( search_res = self.client.search(
@ -70,14 +79,20 @@ class MilvusIndex(EmbeddingIndex):
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: MilvusVectorIOConfig, inference_api: Api.inference) -> None: def __init__(self, config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig], inference_api: Api.inference) -> None:
self.config = config self.config = config
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
self.cache = {} self.cache = {}
self.client = None
self.inference_api = inference_api self.inference_api = inference_api
async def initialize(self) -> None: async def initialize(self) -> None:
pass if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.client.close() self.client.close()
@ -86,9 +101,13 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self, self,
vector_db: VectorDB, vector_db: VectorDB,
) -> None: ) -> None:
if isinstance(self.config, RemoteMilvusVectorIOConfig):
consistency_level = self.config.consistency_level
else:
consistency_level = "Strong"
index = VectorDBWithIndex( index = VectorDBWithIndex(
vector_db=vector_db, vector_db=vector_db,
index=MilvusIndex(self.client, vector_db.identifier), index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
@ -138,3 +157,10 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
# TODO: refactor this generate_chunk_id along with the `sqlite-vec` implementation into a separate utils file