mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-08 21:04:39 +00:00
fix(vector-io): unify score calculation to use cosine and normalize to [0,1]
This commit is contained in:
parent
9618adba89
commit
a0e0c7030b
9 changed files with 166 additions and 42 deletions
|
@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||
|
||||
log = get_logger(name=__name__, category="vector_io::chroma")
|
||||
logger = get_logger(name=__name__, category="vector_io::chroma")
|
||||
|
||||
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
||||
|
||||
|
@ -76,6 +76,9 @@ class ChromaIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
logger.info(
|
||||
f"CHROMA VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||
)
|
||||
results = await maybe_await(
|
||||
self.collection.query(
|
||||
query_embeddings=[embedding.tolist()],
|
||||
|
@ -93,16 +96,19 @@ class ChromaIndex(EmbeddingIndex):
|
|||
doc = json.loads(doc)
|
||||
chunk = Chunk(**doc)
|
||||
except Exception:
|
||||
log.exception(f"Failed to parse document: {doc}")
|
||||
logger.exception(f"Failed to parse document: {doc}")
|
||||
continue
|
||||
|
||||
score = 1.0 / float(dist) if dist != 0 else float("inf")
|
||||
# Cosine distance range [0,2] -> normalized to [0,1]
|
||||
score = 1.0 - (float(dist) / 2.0)
|
||||
logger.info(f"Computed score {score} from distance {dist} for chunk id {chunk.chunk_id}")
|
||||
if score < score_threshold:
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
logger.info(f"CHROMA VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self):
|
||||
|
@ -140,7 +146,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
inference_api: Api.inference,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
logger.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.client = None
|
||||
|
@ -154,7 +160,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.vector_db_store = self.kvstore
|
||||
|
||||
if isinstance(self.config, RemoteChromaVectorIOConfig):
|
||||
log.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
logger.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
url = self.config.url.rstrip("/")
|
||||
parsed = urlparse(url)
|
||||
|
||||
|
@ -163,7 +169,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
|
||||
else:
|
||||
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
|
||||
logger.info(f"Connecting to Chroma local db at: {self.config.db_path}")
|
||||
self.client = chromadb.PersistentClient(path=self.config.db_path)
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
|
@ -177,7 +183,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
collection = await maybe_await(
|
||||
self.client.get_or_create_collection(
|
||||
name=vector_db.identifier,
|
||||
metadata={"vector_db": vector_db.model_dump_json()},
|
||||
metadata={
|
||||
"vector_db": vector_db.model_dump_json(),
|
||||
"hnsw:space": "cosine", # Returns cosine distance
|
||||
},
|
||||
)
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
|
@ -186,7 +195,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
log.warning(f"Vector DB {vector_db_id} not found")
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue