diff --git a/docs/source/concepts/index.md b/docs/source/concepts/index.md index 969e12c1a..9dee2b859 100644 --- a/docs/source/concepts/index.md +++ b/docs/source/concepts/index.md @@ -34,7 +34,7 @@ We are working on adding a few more APIs to complete the application lifecycle. The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include: - LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.), -- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.), +- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.), - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.) Providers come in two flavors: diff --git a/docs/source/index.md b/docs/source/index.md index 4a698e28f..0d0508466 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -68,6 +68,7 @@ A number of "adapters" are available for some popular Inference and Vector Store | FAISS | Single Node | | SQLite-Vec| Single Node | | Chroma | Hosted and Single Node | +| Milvus | Hosted and Single Node | | Postgres (PGVector) | Hosted and Single Node | | Weaviate | Hosted | diff --git a/docs/source/providers/index.md b/docs/source/providers/index.md index 55db9aa13..f8997a281 100644 --- a/docs/source/providers/index.md +++ b/docs/source/providers/index.md @@ -2,7 +2,7 @@ The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include: - LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.), -- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.), +- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.), - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.) Providers come in two flavors: @@ -55,5 +55,6 @@ vector_io/sqlite-vec vector_io/chromadb vector_io/pgvector vector_io/qdrant +vector_io/milvus vector_io/weaviate ``` diff --git a/docs/source/providers/vector_io/mivus.md b/docs/source/providers/vector_io/mivus.md new file mode 100644 index 000000000..8d2f043d5 --- /dev/null +++ b/docs/source/providers/vector_io/mivus.md @@ -0,0 +1,31 @@ +--- +orphan: true +--- +# Milvus + +[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It +allows you to store and query vectors directly within a Milvus database. +That means you're not limited to storing vectors in memory or in a separate service. + +## Features + +- Easy to use +- Fully integrated with Llama Stack + +## Usage + +To use Milvus in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use Milvus. +3. Start storing and querying vectors. + +## Installation + +You can install Milvus using pymilvus: + +```bash +pip install pymilvus +``` +## Documentation +See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general. diff --git a/llama_stack/providers/inline/vector_io/milvus/__init__.py b/llama_stack/providers/inline/vector_io/milvus/__init__.py new file mode 100644 index 000000000..bee6b2ded --- /dev/null +++ b/llama_stack/providers/inline/vector_io/milvus/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + +from .config import MilvusVectorIOConfig + + +async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter + + impl = MilvusVectorIOAdapter(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/vector_io/milvus/config.py b/llama_stack/providers/inline/vector_io/milvus/config.py new file mode 100644 index 000000000..0e11d8c7c --- /dev/null +++ b/llama_stack/providers/inline/vector_io/milvus/config.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from pydantic import BaseModel + +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class MilvusVectorIOConfig(BaseModel): + db_path: str + + @classmethod + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + return {"db_path": "${env.MILVUS_DB_PATH}"} diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index ff4f9caf5..b15b71622 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -110,4 +110,22 @@ def available_providers() -> List[ProviderSpec]: ), api_dependencies=[Api.inference], ), + remote_provider_spec( + Api.vector_io, + AdapterSpec( + adapter_type="milvus", + pip_packages=["pymilvus"], + module="llama_stack.providers.remote.vector_io.milvus", + config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", + ), + api_dependencies=[Api.inference], + ), + InlineProviderSpec( + api=Api.vector_io, + provider_type="inline::milvus", + pip_packages=["pymilvus"], + module="llama_stack.providers.inline.vector_io.milvus", + config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig", + api_dependencies=[Api.inference], + ), ] diff --git a/llama_stack/providers/remote/vector_io/milvus/__init__.py b/llama_stack/providers/remote/vector_io/milvus/__init__.py new file mode 100644 index 000000000..84cb1d748 --- /dev/null +++ b/llama_stack/providers/remote/vector_io/milvus/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + +from .config import MilvusVectorIOConfig + + +async def get_adapter_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from .milvus import MilvusVectorIOAdapter + + assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}" + + impl = MilvusVectorIOAdapter(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/vector_io/milvus/config.py b/llama_stack/providers/remote/vector_io/milvus/config.py new file mode 100644 index 000000000..17da6b23d --- /dev/null +++ b/llama_stack/providers/remote/vector_io/milvus/config.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, Optional + +from pydantic import BaseModel + +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class MilvusVectorIOConfig(BaseModel): + uri: str + token: Optional[str] = None + consistency_level: str = "Strong" + + @classmethod + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"} diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py new file mode 100644 index 000000000..8ca9212bc --- /dev/null +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import hashlib +import logging +import os +import uuid +from typing import Any, Dict, List, Optional, Union + +from numpy.typing import NDArray +from pymilvus import MilvusClient + +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO +from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig +from llama_stack.providers.utils.memory.vector_store import ( + EmbeddingIndex, + VectorDBWithIndex, +) + +from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig + +logger = logging.getLogger(__name__) + + +class MilvusIndex(EmbeddingIndex): + def __init__(self, client: MilvusClient, collection_name: str, consistency_level="Strong"): + self.client = client + self.collection_name = collection_name.replace("-", "_") + self.consistency_level = consistency_level + + async def delete(self): + if self.client.has_collection(self.collection_name): + self.client.drop_collection(collection_name=self.collection_name) + + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + assert len(chunks) == len(embeddings), ( + f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" + ) + if not self.client.has_collection(self.collection_name): + self.client.create_collection( + self.collection_name, + dimension=len(embeddings[0]), + auto_id=True, + consistency_level=self.consistency_level, + ) + + data = [] + for chunk, embedding in zip(chunks, embeddings, strict=False): + chunk_id = generate_chunk_id(chunk.metadata["document_id"], chunk.content) + + data.append( + { + "chunk_id": chunk_id, + "vector": embedding, + "chunk_content": chunk.model_dump(), + } + ) + try: + self.client.insert( + self.collection_name, + data=data, + ) + except Exception as e: + logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}") + raise e + + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + search_res = self.client.search( + collection_name=self.collection_name, + data=[embedding], + limit=k, + output_fields=["*"], + search_params={"params": {"radius": score_threshold}}, + ) + chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]] + scores = [res["distance"] for res in search_res[0]] + return QueryChunksResponse(chunks=chunks, scores=scores) + + +class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): + def __init__( + self, config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig], inference_api: Api.inference + ) -> None: + self.config = config + self.cache = {} + self.client = None + self.inference_api = inference_api + + async def initialize(self) -> None: + if isinstance(self.config, RemoteMilvusVectorIOConfig): + logger.info(f"Connecting to Milvus server at {self.config.uri}") + self.client = MilvusClient(**self.config.model_dump(exclude_none=True)) + else: + logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") + uri = os.path.expanduser(self.config.db_path) + self.client = MilvusClient(uri=uri) + + async def shutdown(self) -> None: + self.client.close() + + async def register_vector_db( + self, + vector_db: VectorDB, + ) -> None: + if isinstance(self.config, RemoteMilvusVectorIOConfig): + consistency_level = self.config.consistency_level + else: + consistency_level = "Strong" + 
index = VectorDBWithIndex( + vector_db=vector_db, + index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level), + inference_api=self.inference_api, + ) + + self.cache[vector_db.identifier] = index + + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]: + if vector_db_id in self.cache: + return self.cache[vector_db_id] + + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) + if not vector_db: + raise ValueError(f"Vector DB {vector_db_id} not found") + + index = VectorDBWithIndex( + vector_db=vector_db, + index=MilvusIndex(client=self.client, collection_name=vector_db.identifier), + inference_api=self.inference_api, + ) + self.cache[vector_db_id] = index + return index + + async def unregister_vector_db(self, vector_db_id: str) -> None: + if vector_db_id in self.cache: + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + + async def insert_chunks( + self, + vector_db_id: str, + chunks: List[Chunk], + ttl_seconds: Optional[int] = None, + ) -> None: + index = await self._get_and_cache_vector_db_index(vector_db_id) + if not index: + raise ValueError(f"Vector DB {vector_db_id} not found") + + await index.insert_chunks(chunks) + + async def query_chunks( + self, + vector_db_id: str, + query: InterleavedContent, + params: Optional[Dict[str, Any]] = None, + ) -> QueryChunksResponse: + index = await self._get_and_cache_vector_db_index(vector_db_id) + if not index: + raise ValueError(f"Vector DB {vector_db_id} not found") + + return await index.query_chunks(query, params) + + +def generate_chunk_id(document_id: str, chunk_text: str) -> str: + """Generate a unique chunk ID using a hash of document ID and chunk text.""" + hash_input = f"{document_id}:{chunk_text}".encode("utf-8") + return str(uuid.UUID(hashlib.md5(hash_input).hexdigest())) + + +# TODO: refactor this generate_chunk_id along with the `sqlite-vec` implementation into a separate utils file
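
For reference, here is a minimal sketch of how the `MilvusIndex` class introduced in this patch could be exercised against Milvus Lite (the inline, file-backed mode). This is not part of the patch itself: it assumes `pymilvus` and this branch of `llama_stack` are installed, and the database path, collection name, embedding dimension, and random vectors are placeholders. In a real deployment the embeddings are produced by the configured inference provider via `VectorDBWithIndex`, not generated by hand.

```python
# Hypothetical usage sketch for the MilvusIndex added in this patch.
# Assumptions: pymilvus + this llama_stack branch installed; paths/names are illustrative.
import asyncio

import numpy as np
from pymilvus import MilvusClient

from llama_stack.apis.vector_io import Chunk
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex


async def main() -> None:
    # Milvus Lite: a local file path is enough, no running Milvus server required.
    client = MilvusClient(uri="./milvus_demo.db")
    index = MilvusIndex(client, "demo_collection")

    chunks = [
        Chunk(content="Llama Stack now supports Milvus.", metadata={"document_id": "doc-1"}),
        Chunk(content="Vectors can be stored locally or in a remote Milvus.", metadata={"document_id": "doc-1"}),
    ]
    # Stand-in embeddings; dimension 384 is arbitrary here.
    embeddings = np.random.rand(len(chunks), 384).astype(np.float32)

    await index.add_chunks(chunks, embeddings)
    response = await index.query(embeddings[0], k=2, score_threshold=0.0)
    for chunk, score in zip(response.chunks, response.scores):
        print(score, chunk.metadata["document_id"])

    await index.delete()


if __name__ == "__main__":
    asyncio.run(main())
```

Pointing `MilvusClient` at an HTTP(S) `uri` (with an optional `token`) instead of a local file would exercise the remote path described by `MilvusVectorIOConfig` in `remote/vector_io/milvus/config.py`, which is how `MilvusVectorIOAdapter.initialize` distinguishes the two modes.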