diff --git a/docs/source/concepts/index.md b/docs/source/concepts/index.md index 969e12c1a..9dee2b859 100644 --- a/docs/source/concepts/index.md +++ b/docs/source/concepts/index.md @@ -34,7 +34,7 @@ We are working on adding a few more APIs to complete the application lifecycle. The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include: - LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.), -- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.), +- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.), - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.) Providers come in two flavors: diff --git a/llama_stack/providers/inline/vector_io/milvus/__init__.py b/llama_stack/providers/inline/vector_io/milvus/__init__.py new file mode 100644 index 000000000..a56a1af58 --- /dev/null +++ b/llama_stack/providers/inline/vector_io/milvus/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + +from .config import MilvusVectorIOConfig + + +async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from .milvus import MilvusVectorIOAdapter + + assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}" + + impl = MilvusVectorIOAdapter(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/vector_io/milvus/config.py b/llama_stack/providers/inline/vector_io/milvus/config.py new file mode 100644 index 000000000..aa7ea6dcf --- /dev/null +++ b/llama_stack/providers/inline/vector_io/milvus/config.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from pydantic import BaseModel + +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class MilvusVectorIOConfig(BaseModel): + db_path: str + + @classmethod + def sample_config(cls) -> Dict[str, Any]: + return {"db_path": "{env.MILVUS_ENDPOINT}"} diff --git a/llama_stack/providers/inline/vector_io/milvus/milvus.py b/llama_stack/providers/inline/vector_io/milvus/milvus.py new file mode 100644 index 000000000..a9860ca32 --- /dev/null +++ b/llama_stack/providers/inline/vector_io/milvus/milvus.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +import os +from typing import Any, Dict, List, Optional + +from numpy.typing import NDArray +from pymilvus import MilvusClient + +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO +from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.utils.memory.vector_store import ( + EmbeddingIndex, + VectorDBWithIndex, +) + +from .config import MilvusVectorIOConfig + +logger = logging.getLogger(__name__) + + +class MilvusIndex(EmbeddingIndex): + def __init__(self, client: MilvusClient, collection_name: str): + self.client = client + self.collection_name = collection_name.replace("-", "_") + + async def delete(self): + if self.client.has_collection(self.collection_name): + self.client.drop_collection(collection_name=self.collection_name) + + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + assert len(chunks) == len(embeddings), ( + f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" + ) + if not self.client.has_collection(self.collection_name): + self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True) + + data = [] + for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)): + chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}" + + data.append( + { + "chunk_id": chunk_id, + "vector": embedding, + "chunk_content": chunk.model_dump(), + } + ) + self.client.insert( + self.collection_name, + data=data, + ) + + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + search_res = self.client.search( + collection_name=self.collection_name, + data=[embedding], + limit=k, + output_fields=["*"], + search_params={"params": {"radius": score_threshold}}, + ) + chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]] + scores = [res["distance"] for res in search_res[0]] + return QueryChunksResponse(chunks=chunks, scores=scores) + + +class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): + def __init__(self, config: MilvusVectorIOConfig, inference_api: Api.inference) -> None: + self.config = config + uri = self.config.model_dump(exclude_none=True)["db_path"] + uri = os.path.expanduser(uri) + self.client = MilvusClient(uri=uri) + self.cache = {} + self.inference_api = inference_api + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + self.client.close() + + async def register_vector_db( + self, + vector_db: VectorDB, + ) -> None: + index = VectorDBWithIndex( + vector_db=vector_db, + index=MilvusIndex(self.client, vector_db.identifier), + inference_api=self.inference_api, + ) + + self.cache[vector_db.identifier] = index + + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]: + if vector_db_id in self.cache: + return self.cache[vector_db_id] + + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) + if not vector_db: + raise ValueError(f"Vector DB {vector_db_id} not found") + + index = VectorDBWithIndex( + vector_db=vector_db, + index=MilvusIndex(client=self.client, collection_name=vector_db.identifier), + inference_api=self.inference_api, + ) + self.cache[vector_db_id] = index + return index + + async def unregister_vector_db(self, vector_db_id: str) -> None: + if vector_db_id in self.cache: + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + + async def insert_chunks( + self, + vector_db_id: str, + chunks: List[Chunk], + ttl_seconds: Optional[int] = None, + ) -> None: + index = await self._get_and_cache_vector_db_index(vector_db_id) + if not index: + raise ValueError(f"Vector DB {vector_db_id} not found") + + await index.insert_chunks(chunks) + + async def query_chunks( + self, + vector_db_id: str, + query: InterleavedContent, + params: Optional[Dict[str, Any]] = None, + ) -> QueryChunksResponse: + index = await self._get_and_cache_vector_db_index(vector_db_id) + if not index: + raise ValueError(f"Vector DB {vector_db_id} not found") + + return await index.query_chunks(query, params) diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index ff4f9caf5..cfb7b343b 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -110,4 +110,22 @@ def available_providers() -> List[ProviderSpec]: ), api_dependencies=[Api.inference], ), + remote_provider_spec( + Api.vector_io, + AdapterSpec( + adapter_type="milvus", + pip_packages=EMBEDDING_DEPS + ["pymilvus"], + module="llama_stack.providers.remote.vector_io.milvus", + config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", + ), + api_dependencies=[Api.inference], + ), + InlineProviderSpec( + api=Api.vector_io, + provider_type="inline::milvus", + pip_packages=EMBEDDING_DEPS + ["pymilvus"], + module="llama_stack.providers.inline.vector_io.milvus", + config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig", + api_dependencies=[Api.inference], + ), ] diff --git a/llama_stack/providers/remote/vector_io/milvus/__init__.py b/llama_stack/providers/remote/vector_io/milvus/__init__.py new file mode 100644 index 000000000..84cb1d748 --- /dev/null +++ b/llama_stack/providers/remote/vector_io/milvus/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + +from .config import MilvusVectorIOConfig + + +async def get_adapter_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from .milvus import MilvusVectorIOAdapter + + assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}" + + impl = MilvusVectorIOAdapter(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/vector_io/milvus/config.py b/llama_stack/providers/remote/vector_io/milvus/config.py new file mode 100644 index 000000000..9dabbb8dc --- /dev/null +++ b/llama_stack/providers/remote/vector_io/milvus/config.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, Optional + +from pydantic import BaseModel + +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class MilvusVectorIOConfig(BaseModel): + uri: str + token: Optional[str] = None + + @classmethod + def sample_config(cls) -> Dict[str, Any]: + return {"uri": "{env.MILVUS_ENDPOINT}", "token": "{env.MILVUS_TOKEN}"} diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py new file mode 100644 index 000000000..f08dd3096 --- /dev/null +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -0,0 +1,140 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +from typing import Any, Dict, List, Optional + +from numpy.typing import NDArray +from pymilvus import MilvusClient + +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO +from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.utils.memory.vector_store import ( + EmbeddingIndex, + VectorDBWithIndex, +) + +from .config import MilvusVectorIOConfig + +logger = logging.getLogger(__name__) + + +class MilvusIndex(EmbeddingIndex): + def __init__(self, client: MilvusClient, collection_name: str): + self.client = client + self.collection_name = collection_name.replace("-", "_") + + async def delete(self): + if self.client.has_collection(self.collection_name): + self.client.drop_collection(collection_name=self.collection_name) + + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + assert len(chunks) == len(embeddings), ( + f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" + ) + if not self.client.has_collection(self.collection_name): + self.client.create_collection(self.collection_name, dimension=len(embeddings[0]), auto_id=True) + + data = [] + for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)): + chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}" + + data.append( + { + "chunk_id": chunk_id, + "vector": embedding, + "chunk_content": chunk.model_dump(), + } + ) + self.client.insert( + self.collection_name, + data=data, + ) + + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + search_res = self.client.search( + collection_name=self.collection_name, + data=[embedding], + limit=k, + output_fields=["*"], + search_params={"params": {"radius": score_threshold}}, + ) + chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]] + scores = [res["distance"] for res in search_res[0]] + return QueryChunksResponse(chunks=chunks, scores=scores) + + +class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): + def __init__(self, config: MilvusVectorIOConfig, inference_api: Api.inference) -> None: + self.config = config + self.client = MilvusClient(**self.config.model_dump(exclude_none=True)) + self.cache = {} + self.inference_api = inference_api + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + self.client.close() + + async def register_vector_db( + self, + vector_db: VectorDB, + ) -> None: + index = VectorDBWithIndex( + vector_db=vector_db, + index=MilvusIndex(self.client, vector_db.identifier), + inference_api=self.inference_api, + ) + + self.cache[vector_db.identifier] = index + + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]: + if vector_db_id in self.cache: + return self.cache[vector_db_id] + + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) + if not vector_db: + raise ValueError(f"Vector DB {vector_db_id} not found") + + index = VectorDBWithIndex( + vector_db=vector_db, + index=MilvusIndex(client=self.client, collection_name=vector_db.identifier), + inference_api=self.inference_api, + ) + self.cache[vector_db_id] = index + return index + + async def unregister_vector_db(self, vector_db_id: str) -> None: + if vector_db_id in self.cache: + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + + async def insert_chunks( + self, + vector_db_id: str, + chunks: List[Chunk], + ttl_seconds: Optional[int] = None, + ) -> None: + index = await self._get_and_cache_vector_db_index(vector_db_id) + if not index: + raise ValueError(f"Vector DB {vector_db_id} not found") + + await index.insert_chunks(chunks) + + async def query_chunks( + self, + vector_db_id: str, + query: InterleavedContent, + params: Optional[Dict[str, Any]] = None, + ) -> QueryChunksResponse: + index = await self._get_and_cache_vector_db_index(vector_db_id) + if not index: + raise ValueError(f"Vector DB {vector_db_id} not found") + + return await index.query_chunks(query, params) diff --git a/tests/api/vector_io/test_vector_io.py b/tests/api/vector_io/test_vector_io.py new file mode 100644 index 000000000..e69de29bb