fix(vector-io): unify score calculation to use cosine and normalize to [0,1]

ChristianZaccaria 2025-09-04 13:03:59 +01:00
parent 9618adba89
commit a0e0c7030b
9 changed files with 166 additions and 42 deletions
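At a glance, the patch settles on two equivalent normalizations: providers that return cosine similarity in [-1, 1] (FAISS inner product, Milvus, Qdrant) rescale with (sim + 1) / 2, while providers that return cosine distance in [0, 2] (sqlite-vec, Chroma, PGVector, Weaviate) map with 1 - dist / 2. A minimal numpy sketch, not part of the patch, showing that both conventions produce the same [0, 1] score for the same pair of vectors:

import numpy as np

def score_from_similarity(sim: float) -> float:
    # cosine similarity [-1, 1] -> [0, 1]
    return (sim + 1.0) / 2.0

def score_from_distance(dist: float) -> float:
    # cosine distance [0, 2] -> [0, 1]
    return 1.0 - dist / 2.0

a = np.array([0.3, 0.8, 0.1], dtype=np.float32)
b = np.array([0.5, 0.2, 0.9], dtype=np.float32)
sim = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
dist = 1.0 - sim  # cosine distance for the same pair

# Both paths agree, so a score_threshold behaves the same across providers.
assert abs(score_from_similarity(sim) - score_from_distance(dist)) < 1e-6
print(round(score_from_similarity(sim), 4))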


@@ -52,7 +52,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class FaissIndex(EmbeddingIndex):
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
self.index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexFlatIP(dimension)
self.chunk_by_index: dict[int, Chunk] = {}
self.kvstore = kvstore
self.bank_id = bank_id
@@ -122,8 +122,12 @@ class FaissIndex(EmbeddingIndex):
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
# Normalize embeddings for cosine similarity
normalized_embeddings = np.array(embeddings).astype(np.float32)
faiss.normalize_L2(normalized_embeddings)
async with self.chunk_id_lock:
self.index.add(np.array(embeddings).astype(np.float32))
self.index.add(normalized_embeddings)
self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
# Save updated index
@@ -160,18 +164,28 @@ class FaissIndex(EmbeddingIndex):
k: int,
score_threshold: float,
) -> QueryChunksResponse:
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
logger.info(
f"FAISS VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
)
# Normalize query embedding for cosine similarity
query_embedding = embedding.reshape(1, -1).astype(np.float32)
faiss.normalize_L2(query_embedding)
distances, indices = await asyncio.to_thread(self.index.search, query_embedding, k)
chunks = []
scores = []
for d, i in zip(distances[0], indices[0], strict=False):
if i < 0:
continue
score = 1.0 / float(d) if d != 0 else float("inf")
# For IndexFlatIP with normalized vectors, d is cosine similarity in [-1,1]
score = (float(d) + 1.0) / 2.0 # rescale to [0,1]
logger.info(f"Computed score {score} from distance {d} for chunk id {self.chunk_ids[int(i)]}")
if score < score_threshold:
continue
chunks.append(self.chunk_by_index[int(i)])
scores.append(score)
logger.info(f"FAISS VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
@@ -241,7 +255,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
"""
try:
vector_dimension = 128 # sample dimension
faiss.IndexFlatL2(vector_dimension)
faiss.IndexFlatIP(vector_dimension)
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
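The FAISS change only yields cosine similarity because both the stored and the query vectors are L2-normalized before IndexFlatIP sees them; without normalize_L2 the inner product is unbounded and the (d + 1) / 2 rescale would not stay in [0, 1]. A small standalone sketch of the intended behaviour, using toy vectors rather than real embeddings:

import faiss
import numpy as np

dim = 4
index = faiss.IndexFlatIP(dim)                 # inner-product index

vectors = np.random.rand(3, dim).astype(np.float32)
faiss.normalize_L2(vectors)                    # in-place: rows become unit vectors
index.add(vectors)

query = vectors[:1].copy()                     # identical to the first stored vector
faiss.normalize_L2(query)
distances, indices = index.search(query, 3)

# With unit vectors the returned "distance" is cosine similarity in [-1, 1];
# the patch rescales it to [0, 1] with (d + 1) / 2.
scores = (distances[0] + 1.0) / 2.0
print(indices[0], scores)                      # best hit is row 0 with score ~1.0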


@@ -109,7 +109,7 @@ class SQLiteVecIndex(EmbeddingIndex):
# Create the virtual table for embeddings.
cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.vector_table}]
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
USING vec0(embedding FLOAT[{self.dimension}] distance_metric=cosine, id TEXT);
""")
connection.commit()
# FTS5 table (for keyword search) - creating both the tables by default. Will use the relevant one
@@ -224,6 +224,9 @@ class SQLiteVecIndex(EmbeddingIndex):
"""
Performs vector-based search using a virtual table for vector similarity.
"""
logger.info(
f"SQLITE-VEC VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
)
def _execute_query():
connection = _create_sqlite_connection(self.db_path)
@@ -248,7 +251,10 @@ class SQLiteVecIndex(EmbeddingIndex):
chunks, scores = [], []
for row in rows:
_id, chunk_json, distance = row
score = 1.0 / distance if distance != 0 else float("inf")
distance = float(distance)
# Cosine distance range [0,2] -> normalized to [0,1]
score = 1.0 - (distance / 2.0)
logger.info(f"Computed score {score} from distance {distance} for chunk id {_id}")
if score < score_threshold:
continue
try:
@@ -258,6 +264,8 @@ class SQLiteVecIndex(EmbeddingIndex):
continue
chunks.append(chunk)
scores.append(score)
logger.info(f"SQLITE-VEC VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
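With distance_metric=cosine on the vec0 column, the distance returned by a KNN query is cosine distance in [0, 2], which the 1 - distance / 2 mapping above folds into [0, 1]. A hedged in-memory sketch (table name and vectors are illustrative, assuming the sqlite-vec Python bindings are installed):

import sqlite3
import sqlite_vec

db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
sqlite_vec.load(db)
db.enable_load_extension(False)

db.execute("CREATE VIRTUAL TABLE demo USING vec0(embedding FLOAT[3] distance_metric=cosine)")
db.execute("INSERT INTO demo(rowid, embedding) VALUES (1, ?)", (sqlite_vec.serialize_float32([1.0, 0.0, 0.0]),))
db.execute("INSERT INTO demo(rowid, embedding) VALUES (2, ?)", (sqlite_vec.serialize_float32([0.0, 1.0, 0.0]),))

rows = db.execute(
    "SELECT rowid, distance FROM demo WHERE embedding MATCH ? AND k = 2 ORDER BY distance",
    (sqlite_vec.serialize_float32([1.0, 0.0, 0.0]),),
).fetchall()
for rowid, dist in rows:
    print(rowid, dist, 1.0 - dist / 2.0)   # identical vector -> score 1.0, orthogonal -> 0.5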


@@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = get_logger(name=__name__, category="vector_io::chroma")
logger = get_logger(name=__name__, category="vector_io::chroma")
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
@@ -76,6 +76,9 @@ class ChromaIndex(EmbeddingIndex):
)
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
logger.info(
f"CHROMA VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
)
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
@@ -93,16 +96,19 @@ class ChromaIndex(EmbeddingIndex):
doc = json.loads(doc)
chunk = Chunk(**doc)
except Exception:
log.exception(f"Failed to parse document: {doc}")
logger.exception(f"Failed to parse document: {doc}")
continue
score = 1.0 / float(dist) if dist != 0 else float("inf")
# Cosine distance range [0,2] -> normalized to [0,1]
score = 1.0 - (float(dist) / 2.0)
logger.info(f"Computed score {score} from distance {dist} for chunk id {chunk.chunk_id}")
if score < score_threshold:
continue
chunks.append(chunk)
scores.append(score)
logger.info(f"CHROMA VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete(self):
@@ -140,7 +146,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
inference_api: Api.inference,
files_api: Files | None,
) -> None:
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
logger.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config
self.inference_api = inference_api
self.client = None
@@ -154,7 +160,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.vector_db_store = self.kvstore
if isinstance(self.config, RemoteChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}")
logger.info(f"Connecting to Chroma server at: {self.config.url}")
url = self.config.url.rstrip("/")
parsed = urlparse(url)
@@ -163,7 +169,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
else:
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
logger.info(f"Connecting to Chroma local db at: {self.config.db_path}")
self.client = chromadb.PersistentClient(path=self.config.db_path)
self.openai_vector_stores = await self._load_openai_vector_stores()
@@ -177,7 +183,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
collection = await maybe_await(
self.client.get_or_create_collection(
name=vector_db.identifier,
metadata={"vector_db": vector_db.model_dump_json()},
metadata={
"vector_db": vector_db.model_dump_json(),
"hnsw:space": "cosine", # Returns cosine distance
},
)
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
@@ -186,7 +195,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
log.warning(f"Vector DB {vector_db_id} not found")
logger.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()
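The hnsw:space entry tells Chroma to report cosine distance for the collection, so the 1 - dist / 2 mapping earlier in this file stays within [0, 1]. A quick sketch with an in-memory Chroma client (collection name and embeddings are made up for illustration):

import chromadb

client = chromadb.EphemeralClient()            # in-memory client, no server needed
collection = client.get_or_create_collection(
    name="score_normalization_demo",
    metadata={"hnsw:space": "cosine"},         # distances come back as cosine distance [0, 2]
)
collection.add(ids=["a", "b"], embeddings=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

result = collection.query(query_embeddings=[[1.0, 0.0, 0.0]], n_results=2)
for dist in result["distances"][0]:
    print(dist, 1.0 - float(dist) / 2.0)       # identical vector -> score 1.0, orthogonal -> 0.5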


@@ -153,6 +153,9 @@ class MilvusIndex(EmbeddingIndex):
raise e
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
logger.info(
f"MILVUS VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
)
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,
@@ -162,8 +165,16 @@ class MilvusIndex(EmbeddingIndex):
output_fields=["*"],
search_params={"params": {"radius": score_threshold}},
)
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
scores = [res["distance"] for res in search_res[0]]
chunks, scores = [], []
for res in search_res[0]:
score = float(res["distance"] + 1.0) / 2.0 # rescale to [0,1]
if score < score_threshold:
continue
chunks.append(Chunk(**res["entity"]["chunk_content"]))
scores.append(score)
logger.info(f"MILVUS VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(


@@ -39,7 +39,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryA
from .config import PGVectorVectorIOConfig
log = get_logger(name=__name__, category="vector_io::pgvector")
logger = get_logger(name=__name__, category="vector_io::pgvector")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
@@ -132,7 +132,7 @@ class PGVectorIndex(EmbeddingIndex):
"""
)
except Exception as e:
log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
logger.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
@@ -179,6 +179,9 @@ class PGVectorIndex(EmbeddingIndex):
Returns:
QueryChunksResponse with combined results
"""
logger.info(
f"PGVECTOR VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
)
pgvector_search_function = self.get_pgvector_search_function()
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
@@ -196,12 +199,15 @@ class PGVectorIndex(EmbeddingIndex):
chunks = []
scores = []
for doc, dist in results:
score = 1.0 / float(dist) if dist != 0 else float("inf")
# Cosine distance range [0,2] -> normalized to [0,1]
score = 1.0 - (float(dist) / 2.0)
logger.info(f"Computed score {score} from distance {dist}")
if score < score_threshold:
continue
chunks.append(Chunk(**doc))
scores.append(score)
logger.info(f"PGVECTOR VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
@@ -356,7 +362,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
logger.info(f"Initializing PGVector memory adapter with config: {self.config}")
self.kvstore = await kvstore_impl(self.config.kvstore)
await self.initialize_openai_vector_stores()
@@ -372,7 +378,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
version = check_extension_version(cur)
if version:
log.info(f"Vector extension version: {version}")
logger.info(f"Vector extension version: {version}")
else:
raise RuntimeError("Vector extension is not installed.")
@@ -385,13 +391,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
"""
)
except Exception as e:
log.exception("Could not connect to PGVector database server")
logger.exception("Could not connect to PGVector database server")
raise RuntimeError("Could not connect to PGVector database server") from e
async def shutdown(self) -> None:
if self.conn is not None:
self.conn.close()
log.info("Connection to PGVector database server closed")
logger.info("Connection to PGVector database server closed")
async def register_vector_db(self, vector_db: VectorDB) -> None:
# Persist vector DB metadata in the KV store


@@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = get_logger(name=__name__, category="vector_io::qdrant")
logger = get_logger(name=__name__, category="vector_io::qdrant")
CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for vector databases
@@ -98,10 +98,13 @@ class QdrantIndex(EmbeddingIndex):
points_selector=models.PointIdsList(points=chunk_ids),
)
except Exception as e:
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
logger.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
raise
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
logger.info(
f"QDRANT VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
)
results = (
await self.client.query_points(
collection_name=self.collection_name,
@@ -120,12 +123,15 @@ class QdrantIndex(EmbeddingIndex):
try:
chunk = Chunk(**point.payload["chunk_content"])
except Exception:
log.exception("Failed to parse chunk")
logger.exception("Failed to parse chunk")
continue
chunks.append(chunk)
scores.append(point.score)
# Cosine similarity range [-1,1] -> normalized to [0,1]
scores.append((point.score + 1.0) / 2.0)
logger.info(f"Computed score {point.score} for chunk id {point.id}")
logger.info(f"QDRANT VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(


@@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import WeaviateVectorIOConfig
log = get_logger(name=__name__, category="vector_io::weaviate")
logger = get_logger(name=__name__, category="vector_io::weaviate")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
@@ -88,6 +88,9 @@ class WeaviateIndex(EmbeddingIndex):
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
logger.info(
f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
)
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
collection = self.client.collections.get(sanitized_collection_name)
@@ -105,16 +108,21 @@ class WeaviateIndex(EmbeddingIndex):
chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict)
except Exception:
log.exception(f"Failed to parse document: {chunk_json}")
logger.exception(f"Failed to parse document: {chunk_json}")
continue
score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
if doc.metadata.distance is None:
continue
# Cosine distance range [0,2] -> normalized to [0,1]
score = 1.0 - (float(doc.metadata.distance) / 2.0)
logger.info(f"Computed score {score} from distance {doc.metadata.distance} for chunk id {chunk.chunk_id}")
if score < score_threshold:
continue
chunks.append(chunk)
scores.append(score)
logger.info(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete(self, chunk_ids: list[str] | None = None) -> None:
@@ -174,7 +182,7 @@ class WeaviateVectorIOAdapter(
def _get_client(self) -> weaviate.Client:
if "localhost" in self.config.weaviate_cluster_url:
log.info("using Weaviate locally in container")
logger.info("using Weaviate locally in container")
host, port = self.config.weaviate_cluster_url.split(":")
key = "local_test"
client = weaviate.connect_to_local(
@@ -182,7 +190,7 @@ class WeaviateVectorIOAdapter(
port=port,
)
else:
log.info("Using Weaviate remote cluster with URL")
logger.info("Using Weaviate remote cluster with URL")
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
if key in self.client_cache:
return self.client_cache[key]
@@ -200,7 +208,7 @@ class WeaviateVectorIOAdapter(
self.kvstore = await kvstore_impl(self.config.kvstore)
else:
self.kvstore = None
log.info("No kvstore configured, registry will not persist across restarts")
logger.info("No kvstore configured, registry will not persist across restarts")
# Load existing vector DB definitions
if self.kvstore is not None:
@@ -257,7 +265,7 @@ class WeaviateVectorIOAdapter(
client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
log.warning(f"Vector DB {sanitized_collection_name} not found")
logger.warning(f"Vector DB {sanitized_collection_name} not found")
return
client.collections.delete(sanitized_collection_name)
await self.cache[sanitized_collection_name].index.delete()


@@ -222,3 +222,63 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
assert len(response.chunks) > 0
assert response.chunks[0].metadata["document_id"] == "doc1"
assert response.chunks[0].metadata["source"] == "precomputed"
def test_vector_similarity_scores_are_normalized(
client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks
):
"""Test that vector similarity scores are properly normalized to [0,1] range for all vector providers."""
vector_db_name = "test_score_normalization_db"
register_response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_name,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
actual_vector_db_id = register_response.identifier
# Insert sample chunks
client_with_empty_registry.vector_io.insert(
vector_db_id=actual_vector_db_id,
chunks=sample_chunks,
)
# Test various queries to ensure score normalization across different similarity levels
test_queries = [
# High similarity query that should match Python doc chunk
"Python programming language with readable code",
# Medium similarity query
"artificial intelligence and machine learning systems",
# Lower similarity query
"What is the capital of France?",
# High similarity query that should match neural networks chunk
"biological neural networks and artificial neurons",
# Very low similarity query to test edge case normalization
"xyzabc random nonsense gibberish qwerty asdfgh",
]
for query in test_queries:
response = client_with_empty_registry.vector_io.query(
vector_db_id=actual_vector_db_id,
query=query,
)
# Verify response structure
assert response is not None, f"Query '{query}' returned None response"
assert len(response.chunks) > 0, f"Query '{query}' returned no chunks"
assert len(response.scores) > 0, f"Query '{query}' returned no scores"
assert len(response.chunks) == len(response.scores), "Mismatch between chunks and scores count"
# Verify all scores are normalized to [0,1] range
for i, score in enumerate(response.scores):
assert isinstance(score, (int | float)), f"Score at index {i} is not numeric: {type(score)}"
assert 0.0 <= score <= 1.0, (
f"Score at index {i} is not normalized: {score} (should be in [0,1] range) for query '{query}'"
)
# Verify scores are in descending order (most similar first)
for i in range(1, len(response.scores)):
assert response.scores[i - 1] >= response.scores[i], (
f"Scores not in descending order at indices {i - 1} and {i}: "
f"{response.scores[i - 1]} >= {response.scores[i]} for query '{query}'"
)


@@ -112,14 +112,16 @@ async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> Fai
yield adapter
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
async def test_faiss_query_vector_returns_perfect_score_when_query_and_embedding_are_identical(
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
):
await faiss_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
with patch.object(faiss_index.index, "search") as mock_search:
mock_search.return_value = (np.array([[0.0, 0.1]]), np.array([[0, 1]]))
# IndexFlatIP with normalized vectors returns cosine similarity scores [-1,1]
# These will be normalized to [0,1] using (score + 1.0) / 2.0
mock_search.return_value = (np.array([[1.0, 0.6]]), np.array([[0, 1]]))
response = await faiss_index.query_vector(embedding=query_embedding, k=2, score_threshold=0.0)
@@ -127,8 +129,8 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
assert len(response.chunks) == 2
assert len(response.scores) == 2
assert response.scores[0] == float("inf") # infinity (1.0 / 0.0)
assert response.scores[1] == 10.0 # (1.0 / 0.1 = 10.0)
assert response.scores[0] == 1.0 # (1.0 + 1.0) / 2.0 = 1.0 (perfect similarity)
assert response.scores[1] == 0.8 # (0.6 + 1.0) / 2.0 = 0.8 (high similarity)
assert response.chunks[0] == sample_chunks[0]
assert response.chunks[1] == sample_chunks[1]
@@ -141,7 +143,7 @@ async def test_health_success():
inference_api = MagicMock()
files_api = MagicMock()
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatIP") as mock_index_flat:
mock_index_flat.return_value = MagicMock()
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
@@ -153,7 +155,7 @@ async def test_health_success():
assert response["status"] == HealthStatus.OK
assert "message" not in response
# Verifying that IndexFlatL2 was called with the correct dimension
# Verifying that IndexFlatIP was called with the correct dimension
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
@@ -164,7 +166,7 @@ async def test_health_failure():
inference_api = MagicMock()
files_api = MagicMock()
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatIP") as mock_index_flat:
mock_index_flat.side_effect = Exception("Test error")
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)