mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
fix(vector-io): unify score calculation to use cosine and normalize to [0,1]
This commit is contained in:
parent
9618adba89
commit
a0e0c7030b
9 changed files with 166 additions and 42 deletions
|
@ -52,7 +52,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
|
|||
|
||||
class FaissIndex(EmbeddingIndex):
|
||||
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
||||
self.index = faiss.IndexFlatL2(dimension)
|
||||
self.index = faiss.IndexFlatIP(dimension)
|
||||
self.chunk_by_index: dict[int, Chunk] = {}
|
||||
self.kvstore = kvstore
|
||||
self.bank_id = bank_id
|
||||
|
@ -122,8 +122,12 @@ class FaissIndex(EmbeddingIndex):
|
|||
for i, chunk in enumerate(chunks):
|
||||
self.chunk_by_index[indexlen + i] = chunk
|
||||
|
||||
# Normalize embeddings for cosine similarity
|
||||
normalized_embeddings = np.array(embeddings).astype(np.float32)
|
||||
faiss.normalize_L2(normalized_embeddings)
|
||||
|
||||
async with self.chunk_id_lock:
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
self.index.add(normalized_embeddings)
|
||||
self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
|
||||
|
||||
# Save updated index
|
||||
|
@ -160,18 +164,28 @@ class FaissIndex(EmbeddingIndex):
|
|||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||
logger.info(
|
||||
f"FAISS VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||
)
|
||||
# Normalize query embedding for cosine similarity
|
||||
query_embedding = embedding.reshape(1, -1).astype(np.float32)
|
||||
faiss.normalize_L2(query_embedding)
|
||||
|
||||
distances, indices = await asyncio.to_thread(self.index.search, query_embedding, k)
|
||||
chunks = []
|
||||
scores = []
|
||||
for d, i in zip(distances[0], indices[0], strict=False):
|
||||
if i < 0:
|
||||
continue
|
||||
score = 1.0 / float(d) if d != 0 else float("inf")
|
||||
# For IndexFlatIP with normalized vectors, d is cosine similarity in [-1,1]
|
||||
score = (float(d) + 1.0) / 2.0 # rescale to [0,1]
|
||||
logger.info(f"Computed score {score} from distance {d} for chunk id {self.chunk_ids[int(i)]}")
|
||||
if score < score_threshold:
|
||||
continue
|
||||
chunks.append(self.chunk_by_index[int(i)])
|
||||
scores.append(score)
|
||||
|
||||
logger.info(f"FAISS VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
|
@ -241,7 +255,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
"""
|
||||
try:
|
||||
vector_dimension = 128 # sample dimension
|
||||
faiss.IndexFlatL2(vector_dimension)
|
||||
faiss.IndexFlatIP(vector_dimension)
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
|
|
@ -109,7 +109,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# Create the virtual table for embeddings.
|
||||
cur.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.vector_table}]
|
||||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||
USING vec0(embedding FLOAT[{self.dimension}] distance_metric=cosine, id TEXT);
|
||||
""")
|
||||
connection.commit()
|
||||
# FTS5 table (for keyword search) - creating both the tables by default. Will use the relevant one
|
||||
|
@ -224,6 +224,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
"""
|
||||
Performs vector-based search using a virtual table for vector similarity.
|
||||
"""
|
||||
logger.info(
|
||||
f"SQLITE-VEC VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||
)
|
||||
|
||||
def _execute_query():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
|
@ -248,7 +251,10 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
chunks, scores = [], []
|
||||
for row in rows:
|
||||
_id, chunk_json, distance = row
|
||||
score = 1.0 / distance if distance != 0 else float("inf")
|
||||
distance = float(distance)
|
||||
# Cosine distance range [0,2] -> normalized to [0,1]
|
||||
score = 1.0 - (distance / 2.0)
|
||||
logger.info(f"Computed score {score} from distance {distance} for chunk id {_id}")
|
||||
if score < score_threshold:
|
||||
continue
|
||||
try:
|
||||
|
@ -258,6 +264,8 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
continue
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
logger.info(f"SQLITE-VEC VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
|
|
|
@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||
|
||||
log = get_logger(name=__name__, category="vector_io::chroma")
|
||||
logger = get_logger(name=__name__, category="vector_io::chroma")
|
||||
|
||||
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
||||
|
||||
|
@ -76,6 +76,9 @@ class ChromaIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
logger.info(
|
||||
f"CHROMA VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||
)
|
||||
results = await maybe_await(
|
||||
self.collection.query(
|
||||
query_embeddings=[embedding.tolist()],
|
||||
|
@ -93,16 +96,19 @@ class ChromaIndex(EmbeddingIndex):
|
|||
doc = json.loads(doc)
|
||||
chunk = Chunk(**doc)
|
||||
except Exception:
|
||||
log.exception(f"Failed to parse document: {doc}")
|
||||
logger.exception(f"Failed to parse document: {doc}")
|
||||
continue
|
||||
|
||||
score = 1.0 / float(dist) if dist != 0 else float("inf")
|
||||
# Cosine distance range [0,2] -> normalized to [0,1]
|
||||
score = 1.0 - (float(dist) / 2.0)
|
||||
logger.info(f"Computed score {score} from distance {dist} for chunk id {chunk.chunk_id}")
|
||||
if score < score_threshold:
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
logger.info(f"CHROMA VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self):
|
||||
|
@ -140,7 +146,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
inference_api: Api.inference,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
logger.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.client = None
|
||||
|
@ -154,7 +160,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.vector_db_store = self.kvstore
|
||||
|
||||
if isinstance(self.config, RemoteChromaVectorIOConfig):
|
||||
log.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
logger.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
url = self.config.url.rstrip("/")
|
||||
parsed = urlparse(url)
|
||||
|
||||
|
@ -163,7 +169,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
|
||||
else:
|
||||
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
|
||||
logger.info(f"Connecting to Chroma local db at: {self.config.db_path}")
|
||||
self.client = chromadb.PersistentClient(path=self.config.db_path)
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
|
@ -177,7 +183,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
collection = await maybe_await(
|
||||
self.client.get_or_create_collection(
|
||||
name=vector_db.identifier,
|
||||
metadata={"vector_db": vector_db.model_dump_json()},
|
||||
metadata={
|
||||
"vector_db": vector_db.model_dump_json(),
|
||||
"hnsw:space": "cosine", # Returns cosine distance
|
||||
},
|
||||
)
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
|
@ -186,7 +195,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
log.warning(f"Vector DB {vector_db_id} not found")
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
|
|
|
@ -153,6 +153,9 @@ class MilvusIndex(EmbeddingIndex):
|
|||
raise e
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
logger.info(
|
||||
f"MILVUS VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||
)
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
collection_name=self.collection_name,
|
||||
|
@ -162,8 +165,16 @@ class MilvusIndex(EmbeddingIndex):
|
|||
output_fields=["*"],
|
||||
search_params={"params": {"radius": score_threshold}},
|
||||
)
|
||||
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
|
||||
scores = [res["distance"] for res in search_res[0]]
|
||||
|
||||
chunks, scores = [], []
|
||||
for res in search_res[0]:
|
||||
score = float(res["distance"] + 1.0) / 2.0 # rescale to [0,1]
|
||||
if score < score_threshold:
|
||||
continue
|
||||
chunks.append(Chunk(**res["entity"]["chunk_content"]))
|
||||
scores.append(score)
|
||||
|
||||
logger.info(f"MILVUS VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
|
|
|
@ -39,7 +39,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryA
|
|||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
||||
log = get_logger(name=__name__, category="vector_io::pgvector")
|
||||
logger = get_logger(name=__name__, category="vector_io::pgvector")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
||||
|
@ -132,7 +132,7 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
|
||||
logger.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
|
||||
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
|
@ -179,6 +179,9 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
logger.info(
|
||||
f"PGVECTOR VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||
)
|
||||
pgvector_search_function = self.get_pgvector_search_function()
|
||||
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
|
@ -196,12 +199,15 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
chunks = []
|
||||
scores = []
|
||||
for doc, dist in results:
|
||||
score = 1.0 / float(dist) if dist != 0 else float("inf")
|
||||
# Cosine distance range [0,2] -> normalized to [0,1]
|
||||
score = 1.0 - (float(dist) / 2.0)
|
||||
logger.info(f"Computed score {score} from distance {dist}")
|
||||
if score < score_threshold:
|
||||
continue
|
||||
chunks.append(Chunk(**doc))
|
||||
scores.append(score)
|
||||
|
||||
logger.info(f"PGVECTOR VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
|
@ -356,7 +362,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
logger.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
|
@ -372,7 +378,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
version = check_extension_version(cur)
|
||||
if version:
|
||||
log.info(f"Vector extension version: {version}")
|
||||
logger.info(f"Vector extension version: {version}")
|
||||
else:
|
||||
raise RuntimeError("Vector extension is not installed.")
|
||||
|
||||
|
@ -385,13 +391,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception("Could not connect to PGVector database server")
|
||||
logger.exception("Could not connect to PGVector database server")
|
||||
raise RuntimeError("Could not connect to PGVector database server") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.conn is not None:
|
||||
self.conn.close()
|
||||
log.info("Connection to PGVector database server closed")
|
||||
logger.info("Connection to PGVector database server closed")
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
# Persist vector DB metadata in the KV store
|
||||
|
|
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||
|
||||
log = get_logger(name=__name__, category="vector_io::qdrant")
|
||||
logger = get_logger(name=__name__, category="vector_io::qdrant")
|
||||
CHUNK_ID_KEY = "_chunk_id"
|
||||
|
||||
# KV store prefixes for vector databases
|
||||
|
@ -98,10 +98,13 @@ class QdrantIndex(EmbeddingIndex):
|
|||
points_selector=models.PointIdsList(points=chunk_ids),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
|
||||
logger.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
|
||||
raise
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
logger.info(
|
||||
f"QDRANT VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||
)
|
||||
results = (
|
||||
await self.client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
|
@ -120,12 +123,15 @@ class QdrantIndex(EmbeddingIndex):
|
|||
try:
|
||||
chunk = Chunk(**point.payload["chunk_content"])
|
||||
except Exception:
|
||||
log.exception("Failed to parse chunk")
|
||||
logger.exception("Failed to parse chunk")
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(point.score)
|
||||
# Cosine similarity range [-1,1] -> normalized to [0,1]
|
||||
scores.append((point.score + 1.0) / 2.0)
|
||||
logger.info(f"Computed score {point.score} for chunk id {point.id}")
|
||||
|
||||
logger.info(f"QDRANT VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_keyword(
|
||||
|
|
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
|||
|
||||
from .config import WeaviateVectorIOConfig
|
||||
|
||||
log = get_logger(name=__name__, category="vector_io::weaviate")
|
||||
logger = get_logger(name=__name__, category="vector_io::weaviate")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||
|
@ -88,6 +88,9 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
logger.info(
|
||||
f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||
)
|
||||
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||
collection = self.client.collections.get(sanitized_collection_name)
|
||||
|
||||
|
@ -105,16 +108,21 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
chunk_dict = json.loads(chunk_json)
|
||||
chunk = Chunk(**chunk_dict)
|
||||
except Exception:
|
||||
log.exception(f"Failed to parse document: {chunk_json}")
|
||||
logger.exception(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
|
||||
if doc.metadata.distance is None:
|
||||
continue
|
||||
# Cosine distance range [0,2] -> normalized to [0,1]
|
||||
score = 1.0 - (float(doc.metadata.distance) / 2.0)
|
||||
logger.info(f"Computed score {score} from distance {doc.metadata.distance} for chunk id {chunk.chunk_id}")
|
||||
if score < score_threshold:
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
logger.info(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self, chunk_ids: list[str] | None = None) -> None:
|
||||
|
@ -174,7 +182,7 @@ class WeaviateVectorIOAdapter(
|
|||
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
if "localhost" in self.config.weaviate_cluster_url:
|
||||
log.info("using Weaviate locally in container")
|
||||
logger.info("using Weaviate locally in container")
|
||||
host, port = self.config.weaviate_cluster_url.split(":")
|
||||
key = "local_test"
|
||||
client = weaviate.connect_to_local(
|
||||
|
@ -182,7 +190,7 @@ class WeaviateVectorIOAdapter(
|
|||
port=port,
|
||||
)
|
||||
else:
|
||||
log.info("Using Weaviate remote cluster with URL")
|
||||
logger.info("Using Weaviate remote cluster with URL")
|
||||
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
|
||||
if key in self.client_cache:
|
||||
return self.client_cache[key]
|
||||
|
@ -200,7 +208,7 @@ class WeaviateVectorIOAdapter(
|
|||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
else:
|
||||
self.kvstore = None
|
||||
log.info("No kvstore configured, registry will not persist across restarts")
|
||||
logger.info("No kvstore configured, registry will not persist across restarts")
|
||||
|
||||
# Load existing vector DB definitions
|
||||
if self.kvstore is not None:
|
||||
|
@ -257,7 +265,7 @@ class WeaviateVectorIOAdapter(
|
|||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
|
||||
if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
|
||||
log.warning(f"Vector DB {sanitized_collection_name} not found")
|
||||
logger.warning(f"Vector DB {sanitized_collection_name} not found")
|
||||
return
|
||||
client.collections.delete(sanitized_collection_name)
|
||||
await self.cache[sanitized_collection_name].index.delete()
|
||||
|
|
|
@ -222,3 +222,63 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
|
|||
assert len(response.chunks) > 0
|
||||
assert response.chunks[0].metadata["document_id"] == "doc1"
|
||||
assert response.chunks[0].metadata["source"] == "precomputed"
|
||||
|
||||
|
||||
def test_vector_similarity_scores_are_normalized(
    client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks
):
    """Test that vector similarity scores are properly normalized to [0,1] range for all vector providers.

    Registers a fresh vector DB, inserts ``sample_chunks``, then issues several
    text queries. For every response it checks that:

    * chunks and scores are both non-empty and have the same length,
    * every score is numeric and lies within [0, 1],
    * scores are in non-increasing order (most similar first).
    """
    vector_db_name = "test_score_normalization_db"
    register_response = client_with_empty_registry.vector_dbs.register(
        vector_db_id=vector_db_name,
        embedding_model=embedding_model_id,
        embedding_dimension=embedding_dimension,
    )

    # Use the identifier returned by registration rather than the requested
    # name for all subsequent calls.
    actual_vector_db_id = register_response.identifier

    # Insert sample chunks
    client_with_empty_registry.vector_io.insert(
        vector_db_id=actual_vector_db_id,
        chunks=sample_chunks,
    )

    # Test various queries to ensure score normalization across different similarity levels
    test_queries = [
        # High similarity query that should match Python doc chunk
        "Python programming language with readable code",
        # Medium similarity query
        "artificial intelligence and machine learning systems",
        # Lower similarity query
        "What is the capital of France?",
        # High similarity query that should match neural networks chunk
        "biological neural networks and artificial neurons",
        # Very low similarity query to test edge case normalization
        "xyzabc random nonsense gibberish qwerty asdfgh",
    ]

    for query in test_queries:
        response = client_with_empty_registry.vector_io.query(
            vector_db_id=actual_vector_db_id,
            query=query,
        )

        # Verify response structure
        assert response is not None, f"Query '{query}' returned None response"
        assert len(response.chunks) > 0, f"Query '{query}' returned no chunks"
        assert len(response.scores) > 0, f"Query '{query}' returned no scores"
        assert len(response.chunks) == len(response.scores), "Mismatch between chunks and scores count"

        # Verify all scores are normalized to [0,1] range
        for i, score in enumerate(response.scores):
            assert isinstance(score, int | float), f"Score at index {i} is not numeric: {type(score)}"
            assert 0.0 <= score <= 1.0, (
                f"Score at index {i} is not normalized: {score} (should be in [0,1] range) for query '{query}'"
            )

        # Verify scores are in descending order (most similar first).
        # Pair each score with its successor; the lengths differ by one on
        # purpose, hence strict=False.
        for i, (previous, current) in enumerate(zip(response.scores, response.scores[1:], strict=False), start=1):
            assert previous >= current, (
                f"Scores not in descending order at indices {i - 1} and {i}: "
                f"expected {previous} >= {current} for query '{query}'"
            )
|
||||
|
|
|
@ -112,14 +112,16 @@ async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> Fai
|
|||
yield adapter
|
||||
|
||||
|
||||
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
|
||||
async def test_faiss_query_vector_returns_perfect_score_when_query_and_embedding_are_identical(
|
||||
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
|
||||
):
|
||||
await faiss_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
|
||||
with patch.object(faiss_index.index, "search") as mock_search:
|
||||
mock_search.return_value = (np.array([[0.0, 0.1]]), np.array([[0, 1]]))
|
||||
# IndexFlatIP with normalized vectors returns cosine similarity scores [-1,1]
|
||||
# These will be normalized to [0,1] using (score + 1.0) / 2.0
|
||||
mock_search.return_value = (np.array([[1.0, 0.6]]), np.array([[0, 1]]))
|
||||
|
||||
response = await faiss_index.query_vector(embedding=query_embedding, k=2, score_threshold=0.0)
|
||||
|
||||
|
@ -127,8 +129,8 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
|
|||
assert len(response.chunks) == 2
|
||||
assert len(response.scores) == 2
|
||||
|
||||
assert response.scores[0] == float("inf") # infinity (1.0 / 0.0)
|
||||
assert response.scores[1] == 10.0 # (1.0 / 0.1 = 10.0)
|
||||
assert response.scores[0] == 1.0 # (1.0 + 1.0) / 2.0 = 1.0 (perfect similarity)
|
||||
assert response.scores[1] == 0.8 # (0.6 + 1.0) / 2.0 = 0.8 (high similarity)
|
||||
|
||||
assert response.chunks[0] == sample_chunks[0]
|
||||
assert response.chunks[1] == sample_chunks[1]
|
||||
|
@ -141,7 +143,7 @@ async def test_health_success():
|
|||
inference_api = MagicMock()
|
||||
files_api = MagicMock()
|
||||
|
||||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
|
||||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatIP") as mock_index_flat:
|
||||
mock_index_flat.return_value = MagicMock()
|
||||
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
|
||||
|
||||
|
@ -153,7 +155,7 @@ async def test_health_success():
|
|||
assert response["status"] == HealthStatus.OK
|
||||
assert "message" not in response
|
||||
|
||||
# Verifying that IndexFlatL2 was called with the correct dimension
|
||||
# Verifying that IndexFlatIP was called with the correct dimension
|
||||
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
|
||||
|
||||
|
||||
|
@ -164,7 +166,7 @@ async def test_health_failure():
|
|||
inference_api = MagicMock()
|
||||
files_api = MagicMock()
|
||||
|
||||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
|
||||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatIP") as mock_index_flat:
|
||||
mock_index_flat.side_effect = Exception("Test error")
|
||||
|
||||
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue