diff --git a/llama_stack/apis/vector_stores/vector_stores.py b/llama_stack/apis/vector_stores/vector_stores.py index 524624028..724d8f3f9 100644 --- a/llama_stack/apis/vector_stores/vector_stores.py +++ b/llama_stack/apis/vector_stores/vector_stores.py @@ -18,6 +18,7 @@ class VectorStore(Resource): :param type: Type of resource, always 'vector_store' for vector stores :param embedding_model: Name of the embedding model to use for vector generation :param embedding_dimension: Dimension of the embedding vectors + :param distance_metric: Distance metric for vector similarity calculations (e.g., 'COSINE', 'L2', 'INNER_PRODUCT') """ type: Literal[ResourceType.vector_store] = ResourceType.vector_store @@ -25,6 +26,7 @@ class VectorStore(Resource): embedding_model: str embedding_dimension: int vector_store_name: str | None = None + distance_metric: str | None = None @property def vector_store_id(self) -> str: @@ -42,6 +44,7 @@ class VectorStoreInput(BaseModel): :param embedding_model: Name of the embedding model to use for vector generation :param embedding_dimension: Dimension of the embedding vectors :param provider_vector_store_id: (Optional) Provider-specific identifier for the vector store + :param distance_metric: (Optional) Distance metric for vector similarity calculations """ vector_store_id: str @@ -49,3 +52,4 @@ class VectorStoreInput(BaseModel): embedding_dimension: int provider_id: str | None = None provider_vector_store_id: str | None = None + distance_metric: str | None = None diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 2b1701dc2..128ad6b3f 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -105,6 +105,7 @@ class VectorIORouter(VectorIO): embedding_model = extra.get("embedding_model") embedding_dimension = extra.get("embedding_dimension") provider_id = extra.get("provider_id") + distance_metric = extra.get("distance_metric") # Use default embedding model if not specified if ( @@ -154,6 +155,7 @@ class VectorIORouter(VectorIO): provider_id=provider_id, provider_vector_store_id=vector_store_id, vector_store_name=params.name, + distance_metric=distance_metric, ) provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier) @@ -162,6 +164,8 @@ class VectorIORouter(VectorIO): params.model_extra = {} params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id params.model_extra["provider_id"] = registered_vector_store.provider_id + if distance_metric is not None: + params.model_extra["distance_metric"] = distance_metric if embedding_model is not None: params.model_extra["embedding_model"] = embedding_model if embedding_dimension is not None: diff --git a/llama_stack/core/routing_tables/vector_stores.py b/llama_stack/core/routing_tables/vector_stores.py index c6c80a01e..08f43370a 100644 --- a/llama_stack/core/routing_tables/vector_stores.py +++ b/llama_stack/core/routing_tables/vector_stores.py @@ -49,6 +49,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): provider_id: str | None = None, provider_vector_store_id: str | None = None, vector_store_name: str | None = None, + distance_metric: str | None = None, ) -> Any: if provider_id is None: if len(self.impls_by_provider_id) > 0: @@ -73,6 +74,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl): embedding_model=embedding_model, embedding_dimension=embedding_dimension, vector_store_name=vector_store_name, + distance_metric=distance_metric, ) await self.register_object(vector_store) return vector_store diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 5e33d4ca3..b56a9289e 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -39,7 +39,11 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten class FaissIndex(EmbeddingIndex): - def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None): + def __init__( + self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None, distance_metric: str = "L2" + ): + self._check_distance_metric_support(distance_metric) + self.distance_metric = distance_metric self.index = faiss.IndexFlatL2(dimension) self.chunk_by_index: dict[int, Chunk] = {} self.kvstore = kvstore @@ -51,8 +55,10 @@ class FaissIndex(EmbeddingIndex): self.chunk_ids: list[Any] = [] @classmethod - async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None): - instance = cls(dimension, kvstore, bank_id) + async def create( + cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None, distance_metric: str = "L2" + ): + instance = cls(dimension, kvstore, bank_id, distance_metric) await instance.initialize() return instance @@ -175,6 +181,22 @@ class FaissIndex(EmbeddingIndex): "Hybrid search is not supported - underlying DB FAISS does not support this search mode" ) + def _check_distance_metric_support(self, distance_metric: str) -> None: + """Check if the distance metric is supported by FAISS. + + Args: + distance_metric: The distance metric to check + + Raises: + NotImplementedError: If the distance metric is not supported yet + """ + if distance_metric != "L2": + # TODO: Implement support for other distance metrics in FAISS + raise NotImplementedError( + f"Distance metric '{distance_metric}' is not yet supported by the FAISS provider. " + f"Currently only 'L2' is supported. Please use 'L2' or switch to a different provider." + ) + class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate): def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None: @@ -229,9 +251,12 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco await self.kvstore.set(key=key, value=vector_store.model_dump_json()) # Store in cache + distance_metric = vector_store.distance_metric or "L2" self.cache[vector_store.identifier] = VectorStoreWithIndex( vector_store=vector_store, - index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier), + index=await FaissIndex.create( + vector_store.embedding_dimension, self.kvstore, vector_store.identifier, distance_metric + ), inference_api=self.inference_api, ) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 37294f173..ab35635b6 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -75,7 +75,14 @@ class SQLiteVecIndex(EmbeddingIndex): - An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search. """ - def __init__(self, dimension: int, db_path: str, bank_id: str, kvstore: KVStore | None = None): + def __init__( + self, + dimension: int, + db_path: str, + bank_id: str, + kvstore: KVStore | None = None, + distance_metric: str = "COSINE", + ): self.dimension = dimension self.db_path = db_path self.bank_id = bank_id @@ -83,10 +90,12 @@ class SQLiteVecIndex(EmbeddingIndex): self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}") self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}") self.kvstore = kvstore + self._check_distance_metric_support(distance_metric) + self.distance_metric = distance_metric @classmethod - async def create(cls, dimension: int, db_path: str, bank_id: str): - instance = cls(dimension, db_path, bank_id) + async def create(cls, dimension: int, db_path: str, bank_id: str, distance_metric: str = "COSINE"): + instance = cls(dimension, db_path, bank_id, distance_metric=distance_metric) await instance.initialize() return instance @@ -373,6 +382,22 @@ class SQLiteVecIndex(EmbeddingIndex): await asyncio.to_thread(_delete_chunks) + def _check_distance_metric_support(self, distance_metric: str) -> None: + """Check if the distance metric is supported by SQLite-vec. + + Args: + distance_metric: The distance metric to check + + Raises: + NotImplementedError: If the distance metric is not supported yet + """ + if distance_metric != "COSINE": + # TODO: Implement support for other distance metrics in SQLite-vec + raise NotImplementedError( + f"Distance metric '{distance_metric}' is not yet supported by the SQLite-vec provider. " + f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider." + ) + class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate): """ @@ -412,8 +437,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro return [v.vector_store for v in self.cache.values()] async def register_vector_store(self, vector_store: VectorStore) -> None: + distance_metric = vector_store.distance_metric or "COSINE" index = await SQLiteVecIndex.create( - vector_store.embedding_dimension, self.config.db_path, vector_store.identifier + vector_store.embedding_dimension, self.config.db_path, vector_store.identifier, distance_metric ) self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api) diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 2663ad43e..3e1a9ea9e 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -45,10 +45,14 @@ async def maybe_await(result): class ChromaIndex(EmbeddingIndex): - def __init__(self, client: ChromaClientType, collection, kvstore: KVStore | None = None): + def __init__( + self, client: ChromaClientType, collection, kvstore: KVStore | None = None, distance_metric: str = "COSINE" + ): self.client = client self.collection = collection self.kvstore = kvstore + self._check_distance_metric_support(distance_metric) + self.distance_metric = distance_metric async def initialize(self): pass @@ -102,6 +106,22 @@ class ChromaIndex(EmbeddingIndex): ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion] await maybe_await(self.collection.delete(ids=ids)) + def _check_distance_metric_support(self, distance_metric: str) -> None: + """Check if the distance metric is supported by Chroma. + + Args: + distance_metric: The distance metric to check + + Raises: + NotImplementedError: If the distance metric is not supported yet + """ + if distance_metric != "COSINE": + # TODO: Implement support for other distance metrics in Chroma + raise NotImplementedError( + f"Distance metric '{distance_metric}' is not yet supported by the Chroma provider. " + f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider." + ) + async def query_hybrid( self, embedding: NDArray, @@ -157,8 +177,9 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc name=vector_store.identifier, metadata={"vector_store": vector_store.model_dump_json()} ) ) + distance_metric = vector_store.distance_metric or "COSINE" self.cache[vector_store.identifier] = VectorStoreWithIndex( - vector_store, ChromaIndex(self.client, collection), self.inference_api + vector_store, ChromaIndex(self.client, collection, distance_metric=distance_metric), self.inference_api ) async def unregister_vector_store(self, vector_store_id: str) -> None: diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index cccf13816..2be3c5cae 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -44,12 +44,19 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten class MilvusIndex(EmbeddingIndex): def __init__( - self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None + self, + client: MilvusClient, + collection_name: str, + consistency_level="Strong", + kvstore: KVStore | None = None, + distance_metric: str = "COSINE", ): self.client = client self.collection_name = sanitize_collection_name(collection_name) self.consistency_level = consistency_level self.kvstore = kvstore + self._check_distance_metric_support(distance_metric) + self.distance_metric = distance_metric async def initialize(self): # MilvusIndex does not require explicit initialization @@ -260,6 +267,22 @@ class MilvusIndex(EmbeddingIndex): logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}") raise + def _check_distance_metric_support(self, distance_metric: str) -> None: + """Check if the distance metric is supported by Milvus. + + Args: + distance_metric: The distance metric to check + + Raises: + NotImplementedError: If the distance metric is not supported yet + """ + if distance_metric != "COSINE": + # TODO: Implement support for other distance metrics in Milvus + raise NotImplementedError( + f"Distance metric '{distance_metric}' is not yet supported by the Milvus provider. " + f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider." + ) + class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate): def __init__( @@ -316,9 +339,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc consistency_level = self.config.consistency_level else: consistency_level = "Strong" + distance_metric = vector_store.distance_metric or "COSINE" index = VectorStoreWithIndex( vector_store=vector_store, - index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level), + index=MilvusIndex( + self.client, + vector_store.identifier, + consistency_level=consistency_level, + distance_metric=distance_metric, + ), inference_api=self.inference_api, ) diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index f28bd3cd9..83ceb3835 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -382,8 +382,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt upsert_models(self.conn, [(vector_store.identifier, vector_store)]) # Create and cache the PGVector index table for the vector DB + distance_metric = vector_store.distance_metric or "COSINE" # Default to COSINE if not specified pgvector_index = PGVectorIndex( - vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore + vector_store=vector_store, + dimension=vector_store.embedding_dimension, + conn=self.conn, + kvstore=self.kvstore, + distance_metric=distance_metric, ) await pgvector_index.initialize() index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api) @@ -420,7 +425,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt if not vector_store: raise VectorStoreNotFoundError(vector_store_id) - index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn) + distance_metric = vector_store.distance_metric or "COSINE" # Default to COSINE if not specified + index = PGVectorIndex( + vector_store, vector_store.embedding_dimension, self.conn, distance_metric=distance_metric + ) await index.initialize() self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api) return self.cache[vector_store_id] diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 93d0894a6..090acf591 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -56,9 +56,11 @@ def convert_id(_id: str) -> str: class QdrantIndex(EmbeddingIndex): - def __init__(self, client: AsyncQdrantClient, collection_name: str): + def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"): self.client = client self.collection_name = collection_name + self._check_distance_metric_support(distance_metric) + self.distance_metric = distance_metric async def initialize(self) -> None: # Qdrant collections are created on-demand in add_chunks @@ -144,6 +146,22 @@ class QdrantIndex(EmbeddingIndex): async def delete(self): await self.client.delete_collection(collection_name=self.collection_name) + def _check_distance_metric_support(self, distance_metric: str) -> None: + """Check if the distance metric is supported by Qdrant. + + Args: + distance_metric: The distance metric to check + + Raises: + NotImplementedError: If the distance metric is not supported yet + """ + if distance_metric != "COSINE": + # TODO: Implement support for other distance metrics in Qdrant + raise NotImplementedError( + f"Distance metric '{distance_metric}' is not yet supported by the Qdrant provider. " + f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider." + ) + class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate): def __init__( @@ -187,9 +205,10 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}" await self.kvstore.set(key=key, value=vector_store.model_dump_json()) + distance_metric = vector_store.distance_metric or "COSINE" index = VectorStoreWithIndex( vector_store=vector_store, - index=QdrantIndex(self.client, vector_store.identifier), + index=QdrantIndex(self.client, vector_store.identifier, distance_metric=distance_metric), inference_api=self.inference_api, ) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 66922aa3f..3576213f3 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -45,10 +45,18 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten class WeaviateIndex(EmbeddingIndex): - def __init__(self, client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None): + def __init__( + self, + client: weaviate.WeaviateClient, + collection_name: str, + kvstore: KVStore | None = None, + distance_metric: str = "COSINE", + ): self.client = client self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True) self.kvstore = kvstore + self._check_distance_metric_support(distance_metric) + self.distance_metric = distance_metric async def initialize(self): pass @@ -82,6 +90,22 @@ class WeaviateIndex(EmbeddingIndex): chunk_ids = [chunk.chunk_id for chunk in chunks_for_deletion] collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids)) + def _check_distance_metric_support(self, distance_metric: str) -> None: + """Check if the distance metric is supported by Weaviate. + + Args: + distance_metric: The distance metric to check + + Raises: + NotImplementedError: If the distance metric is not supported yet + """ + if distance_metric != "COSINE": + # TODO: Implement support for other distance metrics in Weaviate + raise NotImplementedError( + f"Distance metric '{distance_metric}' is not yet supported by the Weaviate provider. " + f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider." + ) + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: """ Performs vector search using Weaviate's built-in vector search. @@ -329,8 +353,11 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv ], ) + distance_metric = vector_store.distance_metric or "COSINE" self.cache[vector_store.identifier] = VectorStoreWithIndex( - vector_store, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api + vector_store, + WeaviateIndex(client=client, collection_name=sanitized_collection_name, distance_metric=distance_metric), + self.inference_api, ) async def unregister_vector_store(self, vector_store_id: str) -> None: diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 8f9fb9fb4..c5139ee87 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -392,6 +392,9 @@ class OpenAIVectorStoreMixin(ABC): if provider_id is None: raise ValueError("Provider ID is required but was not provided") + # Extract distance_metric from extra_body if provided + distance_metric = extra_body.get("distance_metric") + # call to the provider to create any index, etc. vector_store = VectorStore( identifier=vector_store_id, @@ -400,6 +403,7 @@ class OpenAIVectorStoreMixin(ABC): provider_id=provider_id, provider_resource_id=vector_store_id, vector_store_name=params.name, + distance_metric=distance_metric, ) await self.register_vector_store(vector_store)