From fa67d79bfe6c295616815ee5aec136ba9ec20c27 Mon Sep 17 00:00:00 2001
From: Cheney Zhang
Date: Tue, 25 Feb 2025 20:30:54 +0800
Subject: [PATCH] fix: fixed Milvus integration code

Signed-off-by: ChengZi
---
 docs/source/index.md                        |   1 +
 docs/source/providers/index.md              |   3 +-
 docs/source/providers/vector_io/milvus.md   |  31 ++++
 .../inline/vector_io/milvus/__init__.py     |   4 +-
 .../inline/vector_io/milvus/milvus.py       | 143 ------------------
 llama_stack/providers/registry/vector_io.py |   4 +-
 .../remote/vector_io/milvus/config.py       |   1 +
 .../remote/vector_io/milvus/milvus.py       |  52 +++++--
 8 files changed, 77 insertions(+), 162 deletions(-)
 create mode 100644 docs/source/providers/vector_io/milvus.md
 delete mode 100644 llama_stack/providers/inline/vector_io/milvus/milvus.py

diff --git a/docs/source/index.md b/docs/source/index.md
index 4a698e28f..0d0508466 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -68,6 +68,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
 | FAISS | Single Node |
 | SQLite-Vec| Single Node |
 | Chroma | Hosted and Single Node |
+| Milvus | Hosted and Single Node |
 | Postgres (PGVector) | Hosted and Single Node |
 | Weaviate | Hosted |
 
diff --git a/docs/source/providers/index.md b/docs/source/providers/index.md
index 55db9aa13..f8997a281 100644
--- a/docs/source/providers/index.md
+++ b/docs/source/providers/index.md
@@ -2,7 +2,7 @@
 The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
 - LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
-- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.),
+- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
 - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
 
 Providers come in two flavors:
 
@@ -55,5 +55,6 @@ vector_io/sqlite-vec
 vector_io/chromadb
 vector_io/pgvector
 vector_io/qdrant
+vector_io/milvus
 vector_io/weaviate
 ```
diff --git a/docs/source/providers/vector_io/milvus.md b/docs/source/providers/vector_io/milvus.md
new file mode 100644
index 000000000..c57339c97
--- /dev/null
+++ b/docs/source/providers/vector_io/milvus.md
@@ -0,0 +1,31 @@
+---
+orphan: true
+---
+# Milvus
+
+[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
+allows you to store and query vectors directly within a Milvus database.
+That means you're not limited to storing vectors in memory or in a separate service.
+
+## Features
+
+- Easy to use
+- Fully integrated with Llama Stack
+
+## Usage
+
+To use Milvus in your Llama Stack project, follow these steps:
+
+1. Install the necessary dependencies.
+2. Configure your Llama Stack project to use Milvus.
+3. Start storing and querying vectors.
+
+## Installation
+
+You can install Milvus using pymilvus:
+
+```bash
+pip install pymilvus
+```
+## Documentation
+See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
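The "Configure your Llama Stack project to use Milvus" step above corresponds to the two connection modes that `MilvusVectorIOAdapter.initialize()` sets up later in this patch: a local Milvus Lite file for the inline provider and a Milvus server for the remote provider. Below is a minimal pymilvus sketch of both modes, not part of the patch itself; the file path, server address, and credentials are illustrative placeholders.

```python
from pymilvus import MilvusClient

# Inline provider (Milvus Lite): pymilvus manages a local, file-backed instance;
# the adapter expands the configured db_path before connecting.
lite_client = MilvusClient(uri="./milvus_demo.db")  # illustrative local path

# Remote provider: connect to a running Milvus server using the uri/token fields
# of the remote MilvusVectorIOConfig (token is optional).
server_client = MilvusClient(
    uri="http://localhost:19530",  # illustrative server address
    token="root:Milvus",           # placeholder credentials
)
```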
diff --git a/llama_stack/providers/inline/vector_io/milvus/__init__.py b/llama_stack/providers/inline/vector_io/milvus/__init__.py
index a56a1af58..bee6b2ded 100644
--- a/llama_stack/providers/inline/vector_io/milvus/__init__.py
+++ b/llama_stack/providers/inline/vector_io/milvus/__init__.py
@@ -12,9 +12,7 @@
 from .config import MilvusVectorIOConfig
 
 
 async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
-    from .milvus import MilvusVectorIOAdapter
-
-    assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
+    from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
     impl = MilvusVectorIOAdapter(config, deps[Api.inference])
     await impl.initialize()
diff --git a/llama_stack/providers/inline/vector_io/milvus/milvus.py b/llama_stack/providers/inline/vector_io/milvus/milvus.py
deleted file mode 100644
index a9860ca32..000000000
--- a/llama_stack/providers/inline/vector_io/milvus/milvus.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import logging
-import os
-from typing import Any, Dict, List, Optional
-
-from numpy.typing import NDArray
-from pymilvus import MilvusClient
-
-from llama_stack.apis.inference import InterleavedContent
-from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
-from llama_stack.providers.utils.memory.vector_store import (
-    EmbeddingIndex,
-    VectorDBWithIndex,
-)
-
-from .config import MilvusVectorIOConfig
-
-logger = logging.getLogger(__name__)
-
-
-class MilvusIndex(EmbeddingIndex):
-    def __init__(self, client: MilvusClient, collection_name: str):
-        self.client = client
-        self.collection_name = collection_name.replace("-", "_")
-
-    async def delete(self):
-        if self.client.has_collection(self.collection_name):
-            self.client.drop_collection(collection_name=self.collection_name)
-
-    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
-        assert len(chunks) == len(embeddings), (
-            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
-        )
-        if not self.client.has_collection(self.collection_name):
-            self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True)
-
-        data = []
-        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
-            chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
-
-            data.append(
-                {
-                    "chunk_id": chunk_id,
-                    "vector": embedding,
-                    "chunk_content": chunk.model_dump(),
-                }
-            )
-        self.client.insert(
-            self.collection_name,
-            data=data,
-        )
-
-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-        search_res = self.client.search(
-            collection_name=self.collection_name,
-            data=[embedding],
-            limit=k,
-            output_fields=["*"],
-            search_params={"params": {"radius": score_threshold}},
-        )
-        chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
-        scores = [res["distance"] for res in search_res[0]]
-        return QueryChunksResponse(chunks=chunks, scores=scores)
-
-
-class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: MilvusVectorIOConfig, inference_api: Api.inference) -> None:
-        self.config = config
-        uri = self.config.model_dump(exclude_none=True)["db_path"]
-        uri = os.path.expanduser(uri)
-        self.client = MilvusClient(uri=uri)
-        self.cache = {}
-        self.inference_api = inference_api
-
-    async def initialize(self) -> None:
-        pass
-
-    async def shutdown(self) -> None:
-        self.client.close()
-
-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
-        index = VectorDBWithIndex(
-            vector_db=vector_db,
-            index=MilvusIndex(self.client, vector_db.identifier),
-            inference_api=self.inference_api,
-        )
-
-        self.cache[vector_db.identifier] = index
-
-    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
-        if vector_db_id in self.cache:
-            return self.cache[vector_db_id]
-
-        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
-        if not vector_db:
-            raise ValueError(f"Vector DB {vector_db_id} not found")
-
-        index = VectorDBWithIndex(
-            vector_db=vector_db,
-            index=MilvusIndex(client=self.client, collection_name=vector_db.identifier),
-            inference_api=self.inference_api,
-        )
-        self.cache[vector_db_id] = index
-        return index
-
-    async def unregister_vector_db(self, vector_db_id: str) -> None:
-        if vector_db_id in self.cache:
-            await self.cache[vector_db_id].index.delete()
-            del self.cache[vector_db_id]
-
-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: List[Chunk],
-        ttl_seconds: Optional[int] = None,
-    ) -> None:
-        index = await self._get_and_cache_vector_db_index(vector_db_id)
-        if not index:
-            raise ValueError(f"Vector DB {vector_db_id} not found")
-
-        await index.insert_chunks(chunks)
-
-    async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: Optional[Dict[str, Any]] = None,
-    ) -> QueryChunksResponse:
-        index = await self._get_and_cache_vector_db_index(vector_db_id)
-        if not index:
-            raise ValueError(f"Vector DB {vector_db_id} not found")
-
-        return await index.query_chunks(query, params)
diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py
index cfb7b343b..b15b71622 100644
--- a/llama_stack/providers/registry/vector_io.py
+++ b/llama_stack/providers/registry/vector_io.py
@@ -114,7 +114,7 @@ def available_providers() -> List[ProviderSpec]:
         Api.vector_io,
         AdapterSpec(
             adapter_type="milvus",
-            pip_packages=EMBEDDING_DEPS + ["pymilvus"],
+            pip_packages=["pymilvus"],
             module="llama_stack.providers.remote.vector_io.milvus",
             config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
         ),
@@ -123,7 +123,7 @@
     InlineProviderSpec(
         api=Api.vector_io,
         provider_type="inline::milvus",
-        pip_packages=EMBEDDING_DEPS + ["pymilvus"],
+        pip_packages=["pymilvus"],
        module="llama_stack.providers.inline.vector_io.milvus",
         config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
         api_dependencies=[Api.inference],
diff --git a/llama_stack/providers/remote/vector_io/milvus/config.py b/llama_stack/providers/remote/vector_io/milvus/config.py
index 9dabbb8dc..34bd42987 100644
--- a/llama_stack/providers/remote/vector_io/milvus/config.py
+++ b/llama_stack/providers/remote/vector_io/milvus/config.py
@@ -15,6 +15,7 @@ from llama_stack.schema_utils import json_schema_type
 class MilvusVectorIOConfig(BaseModel):
     uri: str
     token: Optional[str] = None
+    consistency_level: str = "Strong"
 
     @classmethod
     def sample_config(cls) -> Dict[str, Any]:
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index f08dd3096..d950a4be0 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -5,8 +5,11 @@
 # the root directory of this source tree.
 
 import logging
-from typing import Any, Dict, List, Optional
+import os
+from typing import Any, Dict, List, Optional, Union
 
+import hashlib
+import uuid
 from numpy.typing import NDArray
 from pymilvus import MilvusClient
 
@@ -14,20 +17,22 @@ from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
 )
 
-from .config import MilvusVectorIOConfig
+from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
 
 logger = logging.getLogger(__name__)
 
 
 class MilvusIndex(EmbeddingIndex):
-    def __init__(self, client: MilvusClient, collection_name: str):
+    def __init__(self, client: MilvusClient, collection_name: str, consistency_level="Strong"):
         self.client = client
         self.collection_name = collection_name.replace("-", "_")
+        self.consistency_level = consistency_level
 
     async def delete(self):
         if self.client.has_collection(self.collection_name):
@@ -38,11 +43,11 @@ class MilvusIndex(EmbeddingIndex):
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
         if not self.client.has_collection(self.collection_name):
-            self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True)
+            self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True, consistency_level=self.consistency_level)
 
         data = []
         for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
-            chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
+            chunk_id = generate_chunk_id(chunk.metadata["document_id"], chunk.content)
 
             data.append(
                 {
@@ -51,10 +56,14 @@
                     "chunk_id": chunk_id,
                     "vector": embedding,
                     "chunk_content": chunk.model_dump(),
                 }
             )
-        self.client.insert(
-            self.collection_name,
-            data=data,
-        )
+        try:
+            self.client.insert(
+                self.collection_name,
+                data=data,
+            )
+        except Exception as e:
+            logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
+            raise e
 
     async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         search_res = self.client.search(
@@ -70,14 +79,20 @@
 
 
 class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: MilvusVectorIOConfig, inference_api: Api.inference) -> None:
+    def __init__(self, config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig], inference_api: Api.inference) -> None:
         self.config = config
-        self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
         self.cache = {}
+        self.client = None
         self.inference_api = inference_api
 
     async def initialize(self) -> None:
-        pass
+        if isinstance(self.config, RemoteMilvusVectorIOConfig):
+            logger.info(f"Connecting to Milvus server at {self.config.uri}")
+            self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
+        else:
+            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
+            uri = os.path.expanduser(self.config.db_path)
+            self.client = MilvusClient(uri=uri)
 
     async def shutdown(self) -> None:
         self.client.close()
@@ -86,9 +101,13 @@
         self,
         vector_db: VectorDB,
     ) -> None:
+        if isinstance(self.config, RemoteMilvusVectorIOConfig):
+            consistency_level = self.config.consistency_level
+        else:
+            consistency_level = "Strong"
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=MilvusIndex(self.client, vector_db.identifier),
+            index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
             inference_api=self.inference_api,
         )
 
@@ -138,3 +157,10 @@
             raise ValueError(f"Vector DB {vector_db_id} not found")
 
         return await index.query_chunks(query, params)
+
+def generate_chunk_id(document_id: str, chunk_text: str) -> str:
+    """Generate a unique chunk ID using a hash of document ID and chunk text."""
+    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
+    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
+
+# TODO: refactor this generate_chunk_id along with the `sqlite-vec` implementation into a separate utils file
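A note on the chunk-ID change above: `generate_chunk_id()` makes IDs deterministic functions of the document ID and chunk text, so re-inserting the same content yields the same key, unlike the old index-based `:chunk-{i}` scheme. A small standalone sketch of that behavior follows; the document IDs and text are made up.

```python
import hashlib
import uuid


def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    """Same scheme as the patch: MD5 of "document_id:chunk_text", rendered as a UUID string."""
    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))


# Deterministic: identical document ID and chunk text always produce the same ID.
assert generate_chunk_id("doc-1", "hello world") == generate_chunk_id("doc-1", "hello world")
# Different chunk text produces a different ID.
assert generate_chunk_id("doc-1", "hello world") != generate_chunk_id("doc-1", "hello there")
print(generate_chunk_id("doc-1", "hello world"))
```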