From 26fb2088771e9f88a61347ee2af9ee26a2b42b64 Mon Sep 17 00:00:00 2001 From: kimbwook Date: Thu, 7 Aug 2025 10:09:29 +0900 Subject: [PATCH] add query_keyword function --- .../remote/vector_io/chroma/chroma.py | 47 +++++++++++++++---- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 442e64f5d..954817837 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -10,7 +10,6 @@ from typing import Any from urllib.parse import urlparse import chromadb -from chromadb.api.models.AsyncCollection import AsyncCollection from numpy.typing import NDArray from llama_stack.apis.files import Files @@ -109,18 +108,46 @@ class ChromaIndex(EmbeddingIndex): await maybe_await(self.client.delete_collection(self.collection.name)) async def query_keyword( - self, - query_string: str, - k: int, - score_threshold: float, + self, + query_string: str, + k: int, + score_threshold: float, ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in Chroma") + results = await maybe_await( + self.collection.query( + query_texts=[query_string], + where_document={"$contains": query_string}, + n_results=k, + include=["documents", "distances"], + ) + ) + + distances = results["distances"][0] if results["distances"] else [] + documents = results["documents"][0] if results["documents"] else [] + + chunks = [] + scores = [] + + for dist, doc in zip(distances, documents, strict=False): + try: + doc_data = json.loads(doc) + chunk = Chunk(**doc_data) + except Exception: + log.exception(f"Failed to parse document: {doc}") + continue + + score = 1.0 / (1.0 + float(dist)) if dist is not None else 1.0 + + if score < score_threshold: + continue + + chunks.append(chunk) + scores.append(score) + + return QueryChunksResponse(chunks=chunks, scores=scores) async def delete_chunk(self, chunk_id: str) -> None: - if isinstance(self.collection, AsyncCollection): - await self.collection.delete([chunk_id]) - else: - self.collection.delete([chunk_id]) + await maybe_await(self.collection.delete([chunk_id])) async def query_hybrid( self,