diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index f9e4bb38e..b49049a8a 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -11468,6 +11468,32 @@ "ttl_seconds": { "type": "integer", "description": "The time to live of the chunks." + }, + "params": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + }, + "description": "Optional parameters for the insertion operation, such as distance_metric for vector databases." } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 9175c97fc..ad76e5535 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -8095,6 +8095,19 @@ components: ttl_seconds: type: integer description: The time to live of the chunks. + params: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + Optional parameters for the insertion operation, such as distance_metric + for vector databases. additionalProperties: false required: - vector_db_id diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 2d4131315..a033cd90f 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -306,6 +306,7 @@ class VectorIO(Protocol): vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None, + params: dict[str, Any] | None = None, ) -> None: """Insert chunks into a vector database. @@ -315,6 +316,7 @@ class VectorIO(Protocol): If `metadata` is provided, you configure how Llama Stack formats the chunk during generation. If `embedding` is not provided, it will be computed later. :param ttl_seconds: The time to live of the chunks. + :param params: Optional parameters for the insertion operation, such as distance_metric for vector databases. """ ... diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 355750b25..98200e733 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -96,7 +96,7 @@ class FaissIndex(EmbeddingIndex): await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}") - async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None): # Add dimension check embedding_dim = embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0] if embedding_dim != self.index.d: @@ -234,6 +234,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None, + params: dict[str, Any] | None = None, ) -> None: index = self.cache.get(vector_db_id) if index is None: diff --git a/llama_stack/providers/inline/vector_io/qdrant/config.py b/llama_stack/providers/inline/vector_io/qdrant/config.py index 61d026984..7cc91d918 100644 --- a/llama_stack/providers/inline/vector_io/qdrant/config.py +++ b/llama_stack/providers/inline/vector_io/qdrant/config.py @@ -5,7 +5,7 @@ # the root directory of this source tree. -from typing import Any, Literal +from typing import Any from pydantic import BaseModel @@ -15,7 +15,6 @@ from llama_stack.schema_utils import json_schema_type @json_schema_type class QdrantVectorIOConfig(BaseModel): path: str - distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE" @classmethod def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 7e977635a..58bf8fa8c 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -178,7 +178,9 @@ class SQLiteVecIndex(EmbeddingIndex): await asyncio.to_thread(_drop_tables) - async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, batch_size: int = 500): + async def add_chunks( + self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None, batch_size: int = 500 + ): """ Add new chunks along with their embeddings using batch inserts. For each chunk, we insert its JSON into the metadata table and then insert its @@ -729,7 +731,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc await asyncio.to_thread(_delete) - async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: + async def insert_chunks( + self, + vector_db_id: str, + chunks: list[Chunk], + ttl_seconds: int | None = None, + params: dict[str, Any] | None = None, + ) -> None: if vector_db_id not in self.cache: raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}") # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 3bef39e9c..cc38553ff 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -55,7 +55,7 @@ class ChromaIndex(EmbeddingIndex): self.client = client self.collection = collection - async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 182227a85..f697579fd 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -53,7 +53,7 @@ class MilvusIndex(EmbeddingIndex): if await asyncio.to_thread(self.client.has_collection, self.collection_name): await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) - async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index c3cdef9b8..96f16146e 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -88,7 +88,7 @@ class PGVectorIndex(EmbeddingIndex): """ ) - async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) diff --git a/llama_stack/providers/remote/vector_io/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py index 0d8e08663..314d3f5f1 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/config.py +++ b/llama_stack/providers/remote/vector_io/qdrant/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Literal +from typing import Any from pydantic import BaseModel @@ -23,7 +23,6 @@ class QdrantVectorIOConfig(BaseModel): prefix: str | None = None timeout: int | None = None host: str | None = None - distance_metric: Literal["COSINE", "DOT", "EUCLID", "MANHATTAN"] = "COSINE" @classmethod def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 8787a4900..3a73c57cd 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -18,17 +18,7 @@ from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, - SearchRankingOptions, VectorIO, - VectorStoreChunkingStrategy, - VectorStoreDeleteResponse, - VectorStoreFileContentsResponse, - VectorStoreFileObject, - VectorStoreFileStatus, - VectorStoreListFilesResponse, - VectorStoreListResponse, - VectorStoreObject, - VectorStoreSearchResponsePage, ) from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig @@ -57,24 +47,41 @@ def convert_id(_id: str) -> str: class QdrantIndex(EmbeddingIndex): - def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"): + def __init__(self, client: AsyncQdrantClient, collection_name: str): self.client = client self.collection_name = collection_name - self.distance_metric = distance_metric + self._distance_metric = None # Will be set when collection is created - async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) + # Extract distance_metric from metadata if provided, default to COSINE + distance_metric = "COSINE" # Default + if metadata is not None and "distance_metric" in metadata: + distance_metric = metadata["distance_metric"] + if not await self.client.collection_exists(self.collection_name): - # Get distance metric, defaulting to COSINE - distance = getattr(models.Distance, self.distance_metric, models.Distance.COSINE) + # Create collection with the specified distance metric + distance = getattr(models.Distance, distance_metric, models.Distance.COSINE) + self._distance_metric = distance_metric await self.client.create_collection( self.collection_name, vectors_config=models.VectorParams(size=len(embeddings[0]), distance=distance), ) + else: + # Collection already exists, warn if different distance metric was requested + if self._distance_metric is None: + # For now, assume COSINE as default since we can't easily extract it from collection info + self._distance_metric = "COSINE" + + if self._distance_metric != distance_metric: + log.warning( + f"Collection {self.collection_name} was created with distance metric '{self._distance_metric}', " + f"but '{distance_metric}' was requested. Using existing distance metric." + ) points = [] for _i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)): @@ -90,6 +97,7 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + # Distance metric is set at collection creation and cannot be changed results = ( await self.client.query_points( collection_name=self.collection_name, @@ -170,9 +178,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP # Create metadata collection if it doesn't exist if not await self.client.collection_exists(metadata_collection): - # Get distance metric from config, defaulting to COSINE for backward compatibility - distance_metric = getattr(self.config, "distance_metric", "COSINE") - distance = getattr(models.Distance, distance_metric, models.Distance.COSINE) + # Use default distance metric for metadata collection + distance = models.Distance.COSINE await self.client.create_collection( collection_name=metadata_collection, @@ -226,13 +233,101 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP collection_name=metadata_collection, points_selector=models.PointIdsList(points=[convert_id(store_id)]) ) + async def _save_openai_vector_store_file( + self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] + ) -> None: + """Save vector store file metadata to Qdrant collection metadata.""" + # Store file metadata in a special collection for vector store file metadata + file_metadata_collection = f"{OPENAI_VECTOR_STORES_METADATA_COLLECTION}_files" + + # Create file metadata collection if it doesn't exist + if not await self.client.collection_exists(file_metadata_collection): + distance = models.Distance.COSINE + await self.client.create_collection( + collection_name=file_metadata_collection, + vectors_config=models.VectorParams(size=1, distance=distance), + ) + + # Store file metadata as a point with dummy vector + file_key = f"{store_id}:{file_id}" + await self.client.upsert( + collection_name=file_metadata_collection, + points=[ + models.PointStruct( + id=convert_id(file_key), + vector=[0.0], # Dummy vector + payload={"file_info": file_info, "file_contents": file_contents}, + ) + ], + ) + + async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]: + """Load vector store file metadata from Qdrant.""" + file_metadata_collection = f"{OPENAI_VECTOR_STORES_METADATA_COLLECTION}_files" + + if not await self.client.collection_exists(file_metadata_collection): + return {} + + file_key = f"{store_id}:{file_id}" + points = await self.client.retrieve( + collection_name=file_metadata_collection, + ids=[convert_id(file_key)], + with_payload=True, + ) + + if points and points[0].payload and "file_info" in points[0].payload: + return points[0].payload["file_info"] + return {} + + async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]: + """Load vector store file contents from Qdrant.""" + file_metadata_collection = f"{OPENAI_VECTOR_STORES_METADATA_COLLECTION}_files" + + if not await self.client.collection_exists(file_metadata_collection): + return [] + + file_key = f"{store_id}:{file_id}" + points = await self.client.retrieve( + collection_name=file_metadata_collection, + ids=[convert_id(file_key)], + with_payload=True, + ) + + if points and points[0].payload and "file_contents" in points[0].payload: + return points[0].payload["file_contents"] + return [] + + async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None: + """Update vector store file metadata in Qdrant.""" + file_metadata_collection = f"{OPENAI_VECTOR_STORES_METADATA_COLLECTION}_files" + + if not await self.client.collection_exists(file_metadata_collection): + return + + # Get existing file contents + existing_contents = await self._load_openai_vector_store_file_contents(store_id, file_id) + + # Update with new file info but keep existing contents + await self._save_openai_vector_store_file(store_id, file_id, file_info, existing_contents) + + async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None: + """Delete vector store file metadata from Qdrant.""" + file_metadata_collection = f"{OPENAI_VECTOR_STORES_METADATA_COLLECTION}_files" + + if await self.client.collection_exists(file_metadata_collection): + file_key = f"{store_id}:{file_id}" + await self.client.delete( + collection_name=file_metadata_collection, + points_selector=models.PointIdsList(points=[convert_id(file_key)]), + ) + async def register_vector_db( self, vector_db: VectorDB, ) -> None: index = VectorDBWithIndex( vector_db=vector_db, - index=QdrantIndex(self.client, vector_db.identifier, self.config.distance_metric), + index=QdrantIndex(self.client, vector_db.identifier), inference_api=self.inference_api, ) @@ -253,9 +348,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP index = VectorDBWithIndex( vector_db=vector_db, - index=QdrantIndex( - client=self.client, collection_name=vector_db.identifier, distance_metric=self.config.distance_metric - ), + index=QdrantIndex(client=self.client, collection_name=vector_db.identifier), inference_api=self.inference_api, ) self.cache[vector_db_id] = index @@ -266,12 +359,23 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None, + params: dict[str, Any] | None = None, ) -> None: index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: raise ValueError(f"Vector DB {vector_db_id} not found") - await index.insert_chunks(chunks) + # Extract distance_metric from params if provided + distance_metric = None + if params is not None: + distance_metric = params.get("distance_metric") + + # Create metadata dict with distance_metric if provided + metadata = None + if distance_metric is not None: + metadata = {"distance_metric": distance_metric} + + await index.insert_chunks(chunks, metadata=metadata) async def query_chunks( self, @@ -284,108 +388,3 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP raise ValueError(f"Vector DB {vector_db_id} not found") return await index.query_chunks(query, params) - - async def openai_create_vector_store( - self, - name: str, - file_ids: list[str] | None = None, - expires_after: dict[str, Any] | None = None, - chunking_strategy: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - embedding_model: str | None = None, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> VectorStoreObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_list_vector_stores( - self, - limit: int | None = 20, - order: str | None = "desc", - after: str | None = None, - before: str | None = None, - ) -> VectorStoreListResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_retrieve_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_update_vector_store( - self, - vector_store_id: str, - name: str | None = None, - expires_after: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - ) -> VectorStoreObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_delete_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreDeleteResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_search_vector_store( - self, - vector_store_id: str, - query: str | list[str], - filters: dict[str, Any] | None = None, - max_num_results: int | None = 10, - ranking_options: SearchRankingOptions | None = None, - rewrite_query: bool | None = False, - search_mode: str | None = "vector", - ) -> VectorStoreSearchResponsePage: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_attach_file_to_vector_store( - self, - vector_store_id: str, - file_id: str, - attributes: dict[str, Any] | None = None, - chunking_strategy: VectorStoreChunkingStrategy | None = None, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_list_files_in_vector_store( - self, - vector_store_id: str, - limit: int | None = 20, - order: str | None = "desc", - after: str | None = None, - before: str | None = None, - filter: VectorStoreFileStatus | None = None, - ) -> VectorStoreListFilesResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_retrieve_vector_store_file( - self, - vector_store_id: str, - file_id: str, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_retrieve_vector_store_file_contents( - self, - vector_store_id: str, - file_id: str, - ) -> VectorStoreFileContentsResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_update_vector_store_file( - self, - vector_store_id: str, - file_id: str, - attributes: dict[str, Any] | None = None, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_delete_vector_store_file( - self, - vector_store_id: str, - file_id: str, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index c63dd70c6..5ecfce31d 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -33,7 +33,7 @@ class WeaviateIndex(EmbeddingIndex): self.client = client self.collection_name = collection_name - async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index ab204a75a..225dda317 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -214,7 +214,7 @@ def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int) class EmbeddingIndex(ABC): @abstractmethod - async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, metadata: dict[str, Any] | None = None): raise NotImplementedError() @abstractmethod @@ -251,6 +251,7 @@ class VectorDBWithIndex: async def insert_chunks( self, chunks: list[Chunk], + distance_metric: str | None = None, ) -> None: chunks_to_embed = [] for i, c in enumerate(chunks): @@ -271,7 +272,13 @@ class VectorDBWithIndex: c.embedding = embedding embeddings = np.array([c.embedding for c in chunks], dtype=np.float32) - await self.index.add_chunks(chunks, embeddings) + + # Create metadata dict with distance_metric if provided + metadata = None + if distance_metric is not None: + metadata = {"distance_metric": distance_metric} + + await self.index.add_chunks(chunks, embeddings, metadata=metadata) async def query_chunks( self,