fix(vector-io): unify score calculation to use cosine and normalize to [0,1]

This commit is contained in:
ChristianZaccaria 2025-09-04 13:03:59 +01:00
parent 9618adba89
commit a0e0c7030b
9 changed files with 166 additions and 42 deletions

View file

@@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import WeaviateVectorIOConfig
log = get_logger(name=__name__, category="vector_io::weaviate")
logger = get_logger(name=__name__, category="vector_io::weaviate")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
@@ -88,6 +88,9 @@ class WeaviateIndex(EmbeddingIndex):
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
logger.info(
f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
)
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
collection = self.client.collections.get(sanitized_collection_name)
@@ -105,16 +108,21 @@ class WeaviateIndex(EmbeddingIndex):
chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict)
except Exception:
log.exception(f"Failed to parse document: {chunk_json}")
logger.exception(f"Failed to parse document: {chunk_json}")
continue
score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
if doc.metadata.distance is None:
continue
# Cosine distance range [0,2] -> normalized to [0,1]
score = 1.0 - (float(doc.metadata.distance) / 2.0)
logger.info(f"Computed score {score} from distance {doc.metadata.distance} for chunk id {chunk.chunk_id}")
if score < score_threshold:
continue
chunks.append(chunk)
scores.append(score)
logger.info(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete(self, chunk_ids: list[str] | None = None) -> None:
@@ -174,7 +182,7 @@ class WeaviateVectorIOAdapter(
def _get_client(self) -> weaviate.Client:
if "localhost" in self.config.weaviate_cluster_url:
log.info("using Weaviate locally in container")
logger.info("using Weaviate locally in container")
host, port = self.config.weaviate_cluster_url.split(":")
key = "local_test"
client = weaviate.connect_to_local(
@@ -182,7 +190,7 @@ class WeaviateVectorIOAdapter(
port=port,
)
else:
log.info("Using Weaviate remote cluster with URL")
logger.info("Using Weaviate remote cluster with URL")
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
if key in self.client_cache:
return self.client_cache[key]
@@ -200,7 +208,7 @@ class WeaviateVectorIOAdapter(
self.kvstore = await kvstore_impl(self.config.kvstore)
else:
self.kvstore = None
log.info("No kvstore configured, registry will not persist across restarts")
logger.info("No kvstore configured, registry will not persist across restarts")
# Load existing vector DB definitions
if self.kvstore is not None:
@@ -257,7 +265,7 @@ class WeaviateVectorIOAdapter(
client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
log.warning(f"Vector DB {sanitized_collection_name} not found")
logger.warning(f"Vector DB {sanitized_collection_name} not found")
return
client.collections.delete(sanitized_collection_name)
await self.cache[sanitized_collection_name].index.delete()