This commit is contained in:
Bwook (Byoungwook) Kim 2025-09-24 09:30:04 +02:00 committed by GitHub
commit c4f3d41c57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 109 additions and 24 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import heapq
import json
from typing import Any
from urllib.parse import urlparse
@ -30,6 +31,7 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
@ -114,7 +116,38 @@ class ChromaIndex(EmbeddingIndex):
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
results = await maybe_await(
self.collection.query(
query_texts=[query_string],
where_document={"$contains": query_string},
n_results=k,
include=["documents", "distances"],
)
)
distances = results["distances"][0] if results["distances"] else []
documents = results["documents"][0] if results["documents"] else []
chunks = []
scores = []
for dist, doc in zip(distances, documents, strict=False):
try:
doc_data = json.loads(doc)
chunk = Chunk(**doc_data)
except Exception:
log.exception(f"Failed to load chunk: {doc}")
continue
score = 1.0 / (1.0 + float(dist)) if dist is not None else 1.0
if score < score_threshold:
continue
chunks.append(chunk)
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a single chunk from the Chroma collection by its ID."""
@ -130,7 +163,57 @@ class ChromaIndex(EmbeddingIndex):
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Chroma")
"""
Hybrid search combining vector similarity and keyword search using configurable reranking.
Args:
embedding: The query embedding vector
query_string: The text query for keyword search
k: Number of results to return
score_threshold: Minimum similarity score threshold
reranker_type: Type of reranker to use ("rrf" or "weighted")
reranker_params: Parameters for the reranker
Returns:
QueryChunksResponse with combined results
"""
if reranker_params is None:
reranker_params = {}
# Get results from both search methods
vector_response = await self.query_vector(embedding, k, score_threshold)
keyword_response = await self.query_keyword(query_string, k, score_threshold)
# Convert responses to score dictionaries using chunk_id
vector_scores = {
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
chunk.chunk_id: score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
# Combine scores using the reranking utility
combined_scores = WeightedInMemoryAggregator.combine_search_results(
vector_scores, keyword_scores, reranker_type, reranker_params
)
# Efficient top-k selection because it only tracks the k best candidates it's seen so far
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
# Filter by score threshold
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create a map of chunk_id to chunk for both responses
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
# Use the map to look up chunks by their IDs
chunks = []
scores = []
for doc_id, score in filtered_items:
if doc_id in chunk_map:
chunks.append(chunk_map[doc_id])
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):

2
uv.lock generated
View file

@ -1767,6 +1767,7 @@ dependencies = [
{ name = "opentelemetry-exporter-otlp-proto-http" },
{ name = "opentelemetry-sdk" },
{ name = "pillow" },
{ name = "pre-commit" },
{ name = "prompt-toolkit" },
{ name = "pydantic" },
{ name = "python-dotenv" },
@ -1892,6 +1893,7 @@ requires-dist = [
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },
{ name = "pandas", marker = "extra == 'ui'" },
{ name = "pillow" },
{ name = "pre-commit", specifier = ">=4.2.0" },
{ name = "prompt-toolkit" },
{ name = "pydantic", specifier = ">=2.11.9" },
{ name = "python-dotenv" },