fix(vector-io): unify score calculation to use cosine and normalize to [0,1]

This commit is contained in:
ChristianZaccaria 2025-09-04 13:03:59 +01:00
parent 9618adba89
commit a0e0c7030b
9 changed files with 166 additions and 42 deletions

View file

@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = get_logger(name=__name__, category="vector_io::chroma")
logger = get_logger(name=__name__, category="vector_io::chroma")
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
@ -76,6 +76,9 @@ class ChromaIndex(EmbeddingIndex):
)
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
logger.info(
f"CHROMA VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
)
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
@ -93,16 +96,19 @@ class ChromaIndex(EmbeddingIndex):
doc = json.loads(doc)
chunk = Chunk(**doc)
except Exception:
log.exception(f"Failed to parse document: {doc}")
logger.exception(f"Failed to parse document: {doc}")
continue
score = 1.0 / float(dist) if dist != 0 else float("inf")
# Cosine distance range [0,2] -> normalized to [0,1]
score = 1.0 - (float(dist) / 2.0)
logger.info(f"Computed score {score} from distance {dist} for chunk id {chunk.chunk_id}")
if score < score_threshold:
continue
chunks.append(chunk)
scores.append(score)
logger.info(f"CHROMA VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete(self):
@ -140,7 +146,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
inference_api: Api.inference,
files_api: Files | None,
) -> None:
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
logger.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config
self.inference_api = inference_api
self.client = None
@ -154,7 +160,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.vector_db_store = self.kvstore
if isinstance(self.config, RemoteChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}")
logger.info(f"Connecting to Chroma server at: {self.config.url}")
url = self.config.url.rstrip("/")
parsed = urlparse(url)
@ -163,7 +169,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
else:
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
logger.info(f"Connecting to Chroma local db at: {self.config.db_path}")
self.client = chromadb.PersistentClient(path=self.config.db_path)
self.openai_vector_stores = await self._load_openai_vector_stores()
@ -177,7 +183,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
collection = await maybe_await(
self.client.get_or_create_collection(
name=vector_db.identifier,
metadata={"vector_db": vector_db.model_dump_json()},
metadata={
"vector_db": vector_db.model_dump_json(),
"hnsw:space": "cosine", # Returns cosine distance
},
)
)
self.cache[vector_db.identifier] = VectorDBWithIndex(
@ -186,7 +195,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
log.warning(f"Vector DB {vector_db_id} not found")
logger.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()