fix: fixed Milvus integration code

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
This commit is contained in:
Cheney Zhang 2025-02-25 20:30:54 +08:00 committed by Ashwin Bharambe
parent 5d0d4c3467
commit fa67d79bfe
8 changed files with 77 additions and 162 deletions

View file

@ -68,6 +68,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
| FAISS | Single Node |
| SQLite-Vec| Single Node |
| Chroma | Hosted and Single Node |
| Milvus | Hosted and Single Node |
| Postgres (PGVector) | Hosted and Single Node |
| Weaviate | Hosted |

View file

@ -2,7 +2,7 @@
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
Providers come in two flavors:
@ -55,5 +55,6 @@ vector_io/sqlite-vec
vector_io/chromadb
vector_io/pgvector
vector_io/qdrant
vector_io/milvus
vector_io/weaviate
```

View file

@ -0,0 +1,31 @@
---
orphan: true
---
# Milvus
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database.
That means you're not limited to storing vectors in memory or in a separate service.
## Features
- Easy to use
- Fully integrated with Llama Stack
## Usage
To use Milvus in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Milvus.
3. Start storing and querying vectors.
## Installation
You can install Milvus using pymilvus:
```bash
pip install pymilvus
```
## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.

View file

@ -12,9 +12,7 @@ from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .milvus import MilvusVectorIOAdapter
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()

View file

@ -1,143 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from typing import Any, Dict, List, Optional
from numpy.typing import NDArray
from pymilvus import MilvusClient
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import MilvusVectorIOConfig
logger = logging.getLogger(__name__)
class MilvusIndex(EmbeddingIndex):
def __init__(self, client: MilvusClient, collection_name: str):
self.client = client
self.collection_name = collection_name.replace("-", "_")
async def delete(self):
if self.client.has_collection(self.collection_name):
self.client.drop_collection(collection_name=self.collection_name)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not self.client.has_collection(self.collection_name):
self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True)
data = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
data.append(
{
"chunk_id": chunk_id,
"vector": embedding,
"chunk_content": chunk.model_dump(),
}
)
self.client.insert(
self.collection_name,
data=data,
)
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = self.client.search(
collection_name=self.collection_name,
data=[embedding],
limit=k,
output_fields=["*"],
search_params={"params": {"radius": score_threshold}},
)
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
scores = [res["distance"] for res in search_res[0]]
return QueryChunksResponse(chunks=chunks, scores=scores)
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: MilvusVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
uri = self.config.model_dump(exclude_none=True)["db_path"]
uri = os.path.expanduser(uri)
self.client = MilvusClient(uri=uri)
self.cache = {}
self.inference_api = inference_api
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
self.client.close()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(self.client, vector_db.identifier),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found")
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
return index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def insert_chunks(
self,
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
await index.insert_chunks(chunks)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)

View file

@ -114,7 +114,7 @@ def available_providers() -> List[ProviderSpec]:
Api.vector_io,
AdapterSpec(
adapter_type="milvus",
pip_packages=EMBEDDING_DEPS + ["pymilvus"],
pip_packages=["pymilvus"],
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
),
@ -123,7 +123,7 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::milvus",
pip_packages=EMBEDDING_DEPS + ["pymilvus"],
pip_packages=["pymilvus"],
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],

View file

@ -15,6 +15,7 @@ from llama_stack.schema_utils import json_schema_type
class MilvusVectorIOConfig(BaseModel):
uri: str
token: Optional[str] = None
consistency_level: str = "Strong"
@classmethod
def sample_config(cls) -> Dict[str, Any]:

View file

@ -5,8 +5,11 @@
# the root directory of this source tree.
import logging
from typing import Any, Dict, List, Optional
import os
from typing import Any, Dict, List, Optional, Union
import hashlib
import uuid
from numpy.typing import NDArray
from pymilvus import MilvusClient
@ -14,20 +17,22 @@ from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import MilvusVectorIOConfig
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = logging.getLogger(__name__)
class MilvusIndex(EmbeddingIndex):
def __init__(self, client: MilvusClient, collection_name: str):
def __init__(self, client: MilvusClient, collection_name: str, consistency_level="Strong"):
self.client = client
self.collection_name = collection_name.replace("-", "_")
self.consistency_level = consistency_level
async def delete(self):
if self.client.has_collection(self.collection_name):
@ -38,11 +43,11 @@ class MilvusIndex(EmbeddingIndex):
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not self.client.has_collection(self.collection_name):
self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True)
self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True, consistency_level=self.consistency_level)
data = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
chunk_id = generate_chunk_id(chunk.metadata["document_id"], chunk.content)
data.append(
{
@ -51,10 +56,14 @@ class MilvusIndex(EmbeddingIndex):
"chunk_content": chunk.model_dump(),
}
)
self.client.insert(
self.collection_name,
data=data,
)
try:
self.client.insert(
self.collection_name,
data=data,
)
except Exception as e:
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = self.client.search(
@ -70,14 +79,20 @@ class MilvusIndex(EmbeddingIndex):
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: MilvusVectorIOConfig, inference_api: Api.inference) -> None:
def __init__(self, config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig], inference_api: Api.inference) -> None:
self.config = config
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
self.cache = {}
self.client = None
self.inference_api = inference_api
async def initialize(self) -> None:
pass
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)
async def shutdown(self) -> None:
self.client.close()
@ -86,9 +101,13 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self,
vector_db: VectorDB,
) -> None:
if isinstance(self.config, RemoteMilvusVectorIOConfig):
consistency_level = self.config.consistency_level
else:
consistency_level = "Strong"
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(self.client, vector_db.identifier),
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
inference_api=self.inference_api,
)
@ -138,3 +157,10 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
# TODO: refactor this generate_chunk_id along with the `sqlite-vec` implementation into a separate utils file