diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index a59a38573..628302a25 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -55,7 +55,7 @@ class ChromaIndex(EmbeddingIndex): ) ) - async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: results = await maybe_await( self.collection.query( query_embeddings=[embedding.tolist()], @@ -76,8 +76,12 @@ class ChromaIndex(EmbeddingIndex): log.exception(f"Failed to parse document: {doc}") continue + score = 1.0 / float(dist) if dist != 0 else float("inf") + if score < score_threshold: + continue + chunks.append(chunk) - scores.append(1.0 / float(dist)) + scores.append(score) return QueryChunksResponse(chunks=chunks, scores=scores)