add query_keyword function

This commit is contained in:
kimbwook 2025-08-07 10:09:29 +09:00
parent 554c78ba66
commit 26fb208877
No known key found for this signature in database
GPG key ID: 13B032C99CBD373A

View file

@ -10,7 +10,6 @@ from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
import chromadb import chromadb
from chromadb.api.models.AsyncCollection import AsyncCollection
from numpy.typing import NDArray from numpy.typing import NDArray
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
@ -109,18 +108,46 @@ class ChromaIndex(EmbeddingIndex):
await maybe_await(self.client.delete_collection(self.collection.name)) await maybe_await(self.client.delete_collection(self.collection.name))
async def query_keyword(
    self,
    query_string: str,
    k: int,
    score_threshold: float,
) -> QueryChunksResponse:
    """Return stored chunks whose document text contains ``query_string``.

    The collection is queried with ``query_string`` both as the query text
    and as a ``$contains`` document filter. Each returned document is parsed
    as JSON into a ``Chunk``; documents that fail to parse are logged and
    skipped.

    Args:
        query_string: Substring every returned document must contain.
        k: Maximum number of results requested from the collection.
        score_threshold: Minimum score a hit must reach to be kept, where
            the score is ``1 / (1 + distance)`` (or ``1.0`` when the
            collection reports no distance for the hit).

    Returns:
        A ``QueryChunksResponse`` whose ``chunks`` and ``scores`` lists are
        aligned index by index.
    """
    # An empty substring filter or a non-positive result count cannot yield
    # a meaningful match, and Chroma's argument validation rejects both, so
    # answer with an empty response instead of issuing the query.
    if not query_string or k <= 0:
        return QueryChunksResponse(chunks=[], scores=[])

    results = await maybe_await(
        self.collection.query(
            query_texts=[query_string],
            where_document={"$contains": query_string},
            n_results=k,
            include=["documents", "distances"],
        )
    )

    # Results come back as one list per query text; a single text was sent,
    # so only the first entry is relevant.
    distances = results["distances"][0] if results["distances"] else []
    documents = results["documents"][0] if results["documents"] else []

    chunks = []
    scores = []
    for dist, doc in zip(distances, documents, strict=False):
        try:
            doc_data = json.loads(doc)
            chunk = Chunk(**doc_data)
        except Exception:
            # Deliberate best effort: one malformed document must not fail
            # the whole search, so record it and move on.
            log.exception("Failed to parse document: %s", doc)
            continue

        # Map a distance (smaller is closer) onto a score (larger is better).
        score = 1.0 / (1.0 + float(dist)) if dist is not None else 1.0
        if score < score_threshold:
            continue

        chunks.append(chunk)
        scores.append(score)

    return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunk(self, chunk_id: str) -> None:
    """Delete the entry with id ``chunk_id`` from the collection.

    The collection's ``delete`` result is passed through ``maybe_await``
    before this coroutine returns.
    """
    pending = self.collection.delete([chunk_id])
    await maybe_await(pending)
async def query_hybrid( async def query_hybrid(
self, self,