mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-13 14:42:50 +00:00
[WIP] Configurable distance_metric:
- Configurable distance_metric enabled for PGVector.
- Added plumbing to support configuring more distance metrics for each vector provider.
This commit is contained in:
parent
658fb2c777
commit
9658581cf7
11 changed files with 187 additions and 18 deletions
|
|
@ -39,7 +39,11 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
|
|||
|
||||
|
||||
class FaissIndex(EmbeddingIndex):
|
||||
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
||||
def __init__(
|
||||
self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None, distance_metric: str = "L2"
|
||||
):
|
||||
self._check_distance_metric_support(distance_metric)
|
||||
self.distance_metric = distance_metric
|
||||
self.index = faiss.IndexFlatL2(dimension)
|
||||
self.chunk_by_index: dict[int, Chunk] = {}
|
||||
self.kvstore = kvstore
|
||||
|
|
@ -51,8 +55,10 @@ class FaissIndex(EmbeddingIndex):
|
|||
self.chunk_ids: list[Any] = []
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
||||
instance = cls(dimension, kvstore, bank_id)
|
||||
async def create(
|
||||
cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None, distance_metric: str = "L2"
|
||||
):
|
||||
instance = cls(dimension, kvstore, bank_id, distance_metric)
|
||||
await instance.initialize()
|
||||
return instance
|
||||
|
||||
|
|
@ -175,6 +181,22 @@ class FaissIndex(EmbeddingIndex):
|
|||
"Hybrid search is not supported - underlying DB FAISS does not support this search mode"
|
||||
)
|
||||
|
||||
def _check_distance_metric_support(self, distance_metric: str) -> None:
|
||||
"""Check if the distance metric is supported by FAISS.
|
||||
|
||||
Args:
|
||||
distance_metric: The distance metric to check
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the distance metric is not supported yet
|
||||
"""
|
||||
if distance_metric != "L2":
|
||||
# TODO: Implement support for other distance metrics in FAISS
|
||||
raise NotImplementedError(
|
||||
f"Distance metric '{distance_metric}' is not yet supported by the FAISS provider. "
|
||||
f"Currently only 'L2' is supported. Please use 'L2' or switch to a different provider."
|
||||
)
|
||||
|
||||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
|
|
@ -229,9 +251,12 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
|
|||
await self.kvstore.set(key=key, value=vector_store.model_dump_json())
|
||||
|
||||
# Store in cache
|
||||
distance_metric = vector_store.distance_metric or "L2"
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(
|
||||
vector_store=vector_store,
|
||||
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
|
||||
index=await FaissIndex.create(
|
||||
vector_store.embedding_dimension, self.kvstore, vector_store.identifier, distance_metric
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -75,7 +75,14 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
- An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int, db_path: str, bank_id: str, kvstore: KVStore | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
dimension: int,
|
||||
db_path: str,
|
||||
bank_id: str,
|
||||
kvstore: KVStore | None = None,
|
||||
distance_metric: str = "COSINE",
|
||||
):
|
||||
self.dimension = dimension
|
||||
self.db_path = db_path
|
||||
self.bank_id = bank_id
|
||||
|
|
@ -83,10 +90,12 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}")
|
||||
self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}")
|
||||
self.kvstore = kvstore
|
||||
self._check_distance_metric_support(distance_metric)
|
||||
self.distance_metric = distance_metric
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dimension: int, db_path: str, bank_id: str):
|
||||
instance = cls(dimension, db_path, bank_id)
|
||||
async def create(cls, dimension: int, db_path: str, bank_id: str, distance_metric: str = "COSINE"):
|
||||
instance = cls(dimension, db_path, bank_id, distance_metric=distance_metric)
|
||||
await instance.initialize()
|
||||
return instance
|
||||
|
||||
|
|
@ -373,6 +382,22 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
await asyncio.to_thread(_delete_chunks)
|
||||
|
||||
def _check_distance_metric_support(self, distance_metric: str) -> None:
|
||||
"""Check if the distance metric is supported by SQLite-vec.
|
||||
|
||||
Args:
|
||||
distance_metric: The distance metric to check
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the distance metric is not supported yet
|
||||
"""
|
||||
if distance_metric != "COSINE":
|
||||
# TODO: Implement support for other distance metrics in SQLite-vec
|
||||
raise NotImplementedError(
|
||||
f"Distance metric '{distance_metric}' is not yet supported by the SQLite-vec provider. "
|
||||
f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider."
|
||||
)
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
|
||||
"""
|
||||
|
|
@ -412,8 +437,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
|
|||
return [v.vector_store for v in self.cache.values()]
|
||||
|
||||
async def register_vector_store(self, vector_store: VectorStore) -> None:
|
||||
distance_metric = vector_store.distance_metric or "COSINE"
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
|
||||
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier, distance_metric
|
||||
)
|
||||
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue