forked from phoenix-oss/llama-stack-mirror
# What does this PR do? This PR converts blocking Milvus Client calls to non-blocking. Another one for https://github.com/meta-llama/llama-stack/issues/1489 ## Test Plan I ran the integration tests from https://github.com/meta-llama/llama-stack/pull/1467 with: ```python pytest -s -v tests/integration/vector_io/test_vector_io.py \ --stack-config inference=sentence-transformers,vector_io=inline::milvus \ --embedding-model all-miniLM-L6-V2 --env MILVUS_DB_PATH=/tmp/moo.db INFO 2025-03-28 21:35:22,726 tests.integration.conftest:41 tests: Setting DISABLE_CODE_SANDBOX=1 for macOS /Users/farceo/dev/llama-stack/.venv/lib/python3.10/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. 
Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) =============================================================================================================================================================================================================================================================== test session starts =============================================================================================================================================================================================================================================================== platform darwin -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /Users/farceo/dev/llama-stack/.venv/bin/python3 cachedir: .pytest_cache metadata: {'Python': '3.10.16', 'Platform': 'macOS-15.3.1-arm64-arm-64bit', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'cov': '6.0.0', 'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0', 'nbval': '0.11.0'}} rootdir: /Users/farceo/dev/llama-stack configfile: pyproject.toml plugins: cov-6.0.0, html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0, nbval-0.11.0 asyncio: mode=strict, asyncio_default_fixture_loop_scope=None collected 7 items tests/integration/vector_io/test_vector_io.py::test_vector_db_retrieve[emb=all-miniLM-L6-V2] PASSED tests/integration/vector_io/test_vector_io.py::test_vector_db_register[emb=all-miniLM-L6-V2] PASSED tests/integration/vector_io/test_vector_io.py::test_insert_chunks[emb=all-miniLM-L6-V2-test_case0] PASSED tests/integration/vector_io/test_vector_io.py::test_insert_chunks[emb=all-miniLM-L6-V2-test_case1] PASSED tests/integration/vector_io/test_vector_io.py::test_insert_chunks[emb=all-miniLM-L6-V2-test_case2] PASSED tests/integration/vector_io/test_vector_io.py::test_insert_chunks[emb=all-miniLM-L6-V2-test_case3] PASSED 
tests/integration/vector_io/test_vector_io.py::test_insert_chunks[emb=all-miniLM-L6-V2-test_case4] PASSED ========================================================================================================================================================================================================================================================= 7 passed, 2 warnings in 40.33s ========================================================================================================================================================================================================================================================== ``` [//]: # (## Documentation) Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
179 lines
6.6 KiB
Python
179 lines
6.6 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
from numpy.typing import NDArray
|
|
from pymilvus import MilvusClient
|
|
|
|
from llama_stack.apis.inference import InterleavedContent
|
|
from llama_stack.apis.vector_dbs import VectorDB
|
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
|
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
|
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
|
from llama_stack.providers.utils.memory.vector_store import (
|
|
EmbeddingIndex,
|
|
VectorDBWithIndex,
|
|
)
|
|
|
|
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MilvusIndex(EmbeddingIndex):
    """Embedding index backed by a single Milvus collection.

    ``MilvusClient`` methods are synchronous, so every client call below is
    dispatched to a worker thread via ``asyncio.to_thread`` to keep the event
    loop responsive.
    """

    def __init__(self, client: MilvusClient, collection_name: str, consistency_level: str = "Strong"):
        """Wrap an existing Milvus client for one collection.

        Args:
            client: Milvus client used for all collection operations.
            collection_name: Vector DB identifier; hyphens are replaced with
                underscores before it is used as the Milvus collection name.
            consistency_level: Consistency level passed to ``create_collection``
                when the collection is created by ``add_chunks``.
        """
        self.client = client
        self.collection_name = collection_name.replace("-", "_")
        self.consistency_level = consistency_level

    async def delete(self):
        """Drop the backing collection if it exists."""
        if await asyncio.to_thread(self.client.has_collection, self.collection_name):
            await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        """Insert chunks together with their embeddings.

        The collection is created on first insert, sized from the length of the
        first embedding vector.

        Args:
            chunks: Chunks to store; each must carry ``metadata["document_id"]``,
                which is combined with the chunk content to derive ``chunk_id``.
            embeddings: One embedding per chunk, in the same order as ``chunks``.

        Raises:
            AssertionError: If ``chunks`` and ``embeddings`` differ in length.
            Exception: Any error raised by the Milvus insert is logged and re-raised.
        """
        assert len(chunks) == len(embeddings), (
            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
        )
        if len(chunks) == 0:
            # Nothing to insert. Returning here also avoids indexing embeddings[0]
            # below when the collection does not exist yet.
            return

        if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
            await asyncio.to_thread(
                self.client.create_collection,
                self.collection_name,
                dimension=len(embeddings[0]),
                auto_id=True,
                consistency_level=self.consistency_level,
            )

        data = [
            {
                "chunk_id": generate_chunk_id(chunk.metadata["document_id"], chunk.content),
                "vector": embedding,
                "chunk_content": chunk.model_dump(),
            }
            for chunk, embedding in zip(chunks, embeddings, strict=False)
        ]
        try:
            await asyncio.to_thread(
                self.client.insert,
                self.collection_name,
                data=data,
            )
        except Exception as e:
            logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
            # Bare raise re-raises the active exception without adding this frame.
            raise

    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        """Return up to ``k`` stored chunks nearest to ``embedding``.

        Args:
            embedding: Query vector.
            k: Maximum number of hits to return.
            score_threshold: Forwarded as the Milvus ``radius`` search parameter
                (its filtering direction depends on the collection's metric).

        Returns:
            The matching chunks, rebuilt from their stored ``chunk_content``,
            paired with the ``distance`` Milvus reports for each hit.
        """
        search_res = await asyncio.to_thread(
            self.client.search,
            collection_name=self.collection_name,
            data=[embedding],
            limit=k,
            output_fields=["*"],
            search_params={"params": {"radius": score_threshold}},
        )
        # A single query vector was sent, so only the first result list is relevant.
        hits = search_res[0]
        chunks = [Chunk(**res["entity"]["chunk_content"]) for res in hits]
        scores = [res["distance"] for res in hits]
        return QueryChunksResponse(chunks=chunks, scores=scores)
|
|
|
|
|
|
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    """VectorIO provider storing each registered vector DB in its own Milvus collection.

    Depending on the config type it connects to a remote Milvus server
    (``RemoteMilvusVectorIOConfig``) or opens a local Milvus Lite database
    (``InlineMilvusVectorIOConfig``).
    """

    def __init__(
        self, config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig], inference_api: Api.inference
    ) -> None:
        self.config = config
        # vector_db_id -> index wrapper, filled by register_vector_db or lazily on first use.
        self.cache: Dict[str, VectorDBWithIndex] = {}
        # Created in initialize().
        self.client: Optional[MilvusClient] = None
        self.inference_api = inference_api
        # NOTE(review): self.vector_db_store is read below but not set here;
        # presumably it is assigned by the provider framework — confirm.

    async def initialize(self) -> None:
        """Create the Milvus client for the configured backend."""
        if isinstance(self.config, RemoteMilvusVectorIOConfig):
            logger.info(f"Connecting to Milvus server at {self.config.uri}")
            self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
        else:
            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
            uri = os.path.expanduser(self.config.db_path)
            self.client = MilvusClient(uri=uri)

    async def shutdown(self) -> None:
        """Close the Milvus client if one was created.

        ``close`` is synchronous, so it runs in a worker thread like the other
        client calls in this module.
        """
        if self.client is not None:
            await asyncio.to_thread(self.client.close)

    def _consistency_level(self) -> str:
        """Return the collection consistency level for the active config.

        Remote configs carry an explicit ``consistency_level``; Milvus Lite
        always uses "Strong".
        """
        if isinstance(self.config, RemoteMilvusVectorIOConfig):
            return self.config.consistency_level
        return "Strong"

    def _make_index(self, vector_db: VectorDB) -> VectorDBWithIndex:
        """Build the index wrapper for ``vector_db``.

        This is the single construction path, so registered indexes and
        lazily rebuilt ones use the same consistency level.
        """
        return VectorDBWithIndex(
            vector_db=vector_db,
            index=MilvusIndex(
                client=self.client,
                collection_name=vector_db.identifier,
                consistency_level=self._consistency_level(),
            ),
            inference_api=self.inference_api,
        )

    async def register_vector_db(
        self,
        vector_db: VectorDB,
    ) -> None:
        """Create and cache the index wrapper for a newly registered vector DB."""
        self.cache[vector_db.identifier] = self._make_index(vector_db)

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
        """Return the cached index for ``vector_db_id``, building it from the store if needed.

        Raises:
            ValueError: If the vector DB is unknown to ``vector_db_store``.
        """
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
        if not vector_db:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        index = self._make_index(vector_db)
        self.cache[vector_db_id] = index
        return index

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Drop the collection of a cached vector DB and evict it from the cache.

        Only vector DBs present in the cache are dropped.
        """
        if vector_db_id in self.cache:
            await self.cache[vector_db_id].index.delete()
            del self.cache[vector_db_id]

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: List[Chunk],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """Insert ``chunks`` into the given vector DB (``ttl_seconds`` is not used)."""
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryChunksResponse:
        """Query the given vector DB, delegating to its index wrapper."""
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        return await index.query_chunks(query, params)
|
|
|
|
|
|
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    """Generate a deterministic chunk ID from a document ID and chunk text.

    The MD5 digest of ``"{document_id}:{chunk_text}"`` (UTF-8 encoded) is
    formatted as a UUID string, so the same chunk of the same document always
    maps to the same ID.

    MD5 serves purely as a fingerprint here. Passing ``usedforsecurity=False``
    leaves the digest unchanged and lets this run on FIPS-restricted Python
    builds, where a plain ``hashlib.md5()`` call is rejected.
    """
    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
    digest = hashlib.md5(hash_input, usedforsecurity=False).hexdigest()
    return str(uuid.UUID(digest))
|
|
|
|
|
|
# TODO: refactor this generate_chunk_id along with the `sqlite-vec` implementation into a separate utils file
|