mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-08 21:04:39 +00:00
fix(vector-io): unify score calculation to use cosine and normalize to [0,1]
This commit is contained in:
parent
9618adba89
commit
a0e0c7030b
9 changed files with 166 additions and 42 deletions
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
|||
|
||||
from .config import WeaviateVectorIOConfig
|
||||
|
||||
log = get_logger(name=__name__, category="vector_io::weaviate")
|
||||
logger = get_logger(name=__name__, category="vector_io::weaviate")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||
|
@ -88,6 +88,9 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
logger.info(
|
||||
f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||
)
|
||||
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||
collection = self.client.collections.get(sanitized_collection_name)
|
||||
|
||||
|
@ -105,16 +108,21 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
chunk_dict = json.loads(chunk_json)
|
||||
chunk = Chunk(**chunk_dict)
|
||||
except Exception:
|
||||
log.exception(f"Failed to parse document: {chunk_json}")
|
||||
logger.exception(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
|
||||
if doc.metadata.distance is None:
|
||||
continue
|
||||
# Cosine distance range [0,2] -> normalized to [0,1]
|
||||
score = 1.0 - (float(doc.metadata.distance) / 2.0)
|
||||
logger.info(f"Computed score {score} from distance {doc.metadata.distance} for chunk id {chunk.chunk_id}")
|
||||
if score < score_threshold:
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
logger.info(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self, chunk_ids: list[str] | None = None) -> None:
|
||||
|
@ -174,7 +182,7 @@ class WeaviateVectorIOAdapter(
|
|||
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
if "localhost" in self.config.weaviate_cluster_url:
|
||||
log.info("using Weaviate locally in container")
|
||||
logger.info("using Weaviate locally in container")
|
||||
host, port = self.config.weaviate_cluster_url.split(":")
|
||||
key = "local_test"
|
||||
client = weaviate.connect_to_local(
|
||||
|
@ -182,7 +190,7 @@ class WeaviateVectorIOAdapter(
|
|||
port=port,
|
||||
)
|
||||
else:
|
||||
log.info("Using Weaviate remote cluster with URL")
|
||||
logger.info("Using Weaviate remote cluster with URL")
|
||||
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
|
||||
if key in self.client_cache:
|
||||
return self.client_cache[key]
|
||||
|
@ -200,7 +208,7 @@ class WeaviateVectorIOAdapter(
|
|||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
else:
|
||||
self.kvstore = None
|
||||
log.info("No kvstore configured, registry will not persist across restarts")
|
||||
logger.info("No kvstore configured, registry will not persist across restarts")
|
||||
|
||||
# Load existing vector DB definitions
|
||||
if self.kvstore is not None:
|
||||
|
@ -257,7 +265,7 @@ class WeaviateVectorIOAdapter(
|
|||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
|
||||
if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
|
||||
log.warning(f"Vector DB {sanitized_collection_name} not found")
|
||||
logger.warning(f"Vector DB {sanitized_collection_name} not found")
|
||||
return
|
||||
client.collections.delete(sanitized_collection_name)
|
||||
await self.cache[sanitized_collection_name].index.delete()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue