This commit is contained in:
Bwook (Byoungwook) Kim 2025-12-03 01:04:10 +00:00 committed by GitHub
commit cae28be572
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 152 additions and 42 deletions

View file

View file

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import heapq
import json
from typing import Any
from urllib.parse import urlparse
@ -16,6 +17,7 @@ from llama_stack.log import get_logger
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
from llama_stack_api import (
Chunk,
Files,
@ -99,8 +101,55 @@ class ChromaIndex(EmbeddingIndex):
async def delete(self):
await maybe_await(self.client.delete_collection(self.collection.name))
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""
Perform keyword search using Chroma's built-in where_document feature.
Args:
query_string: The text query for keyword search
k: Number of results to return
score_threshold: Minimum similarity score threshold
Returns:
QueryChunksResponse with combined results
"""
try:
results = await maybe_await(
self.collection.query(
query_texts=[query_string],
where_document={"$contains": query_string},
n_results=k,
include=["documents", "distances"],
)
)
except Exception as e:
log.error(f"Chroma client keyword search failed: {e}")
raise
distances = results["distances"][0] if results["distances"] else []
documents = results["documents"][0] if results["documents"] else []
chunks = []
scores = []
for dist, doc in zip(distances, documents, strict=False):
doc_data = json.loads(doc)
chunk = Chunk(**doc_data)
score = 1.0 / (1.0 + float(dist)) if dist is not None else 1.0
if score < score_threshold:
continue
chunks.append(chunk)
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
"""Delete a single chunk from the Chroma collection by its ID."""
@ -116,7 +165,57 @@ class ChromaIndex(EmbeddingIndex):
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Chroma")
"""
Hybrid search combining vector similarity and keyword search using configurable reranking.
Args:
embedding: The query embedding vector
query_string: The text query for keyword search
k: Number of results to return
score_threshold: Minimum similarity score threshold
reranker_type: Type of reranker to use ("rrf" or "weighted")
reranker_params: Parameters for the reranker
Returns:
QueryChunksResponse with combined results
"""
if reranker_params is None:
reranker_params = {}
# Get results from both search methods
vector_response = await self.query_vector(embedding, k, score_threshold)
keyword_response = await self.query_keyword(query_string, k, score_threshold)
# Convert responses to score dictionaries using chunk_id
vector_scores = {
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
chunk.chunk_id: score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
# Combine scores using the reranking utility
combined_scores = WeightedInMemoryAggregator.combine_search_results(
vector_scores, keyword_scores, reranker_type, reranker_params
)
# Efficient top-k selection because it only tracks the k best candidates it's seen so far
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
# Filter by score threshold
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create a map of chunk_id to chunk for both responses
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
# Use the map to look up chunks by their IDs
chunks = []
scores = []
for doc_id, score in filtered_items:
if doc_id in chunk_map:
chunks.append(chunk_map[doc_id])
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):

View file

@ -92,6 +92,13 @@ class OpenAIVectorStoreMixin(ABC):
self.kvstore = kvstore
self._last_file_batch_cleanup_time = 0
self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
self._vector_store_locks: dict[str, asyncio.Lock] = {}
def _get_vector_store_lock(self, vector_store_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific vector store."""
if vector_store_id not in self._vector_store_locks:
self._vector_store_locks[vector_store_id] = asyncio.Lock()
return self._vector_store_locks[vector_store_id]
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to persistent storage."""
@ -831,7 +838,12 @@ class OpenAIVectorStoreMixin(ABC):
await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks)
# Update file_ids and file_counts in vector store metadata
# Use lock to prevent race condition when multiple files are attached concurrently
async with self._get_vector_store_lock(vector_store_id):
store_info = self.openai_vector_stores[vector_store_id].copy()
# Deep copy file_counts to avoid mutating shared dict
store_info["file_counts"] = store_info["file_counts"].copy()
store_info["file_ids"] = store_info["file_ids"].copy()
store_info["file_ids"].append(file_id)
store_info["file_counts"]["total"] += 1
store_info["file_counts"][vector_store_file_object.status] += 1
@ -839,9 +851,6 @@ class OpenAIVectorStoreMixin(ABC):
# Save updated vector store to persistent storage
await self._save_openai_vector_store(vector_store_id, store_info)
# Update vector store in-memory cache
self.openai_vector_stores[vector_store_id] = store_info
return vector_store_file_object
async def openai_list_files_in_vector_store(

View file

@ -61,6 +61,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
"remote::milvus",
"remote::pgvector",
"remote::weaviate",
"remote::chromadb",
],
"hybrid": [
"inline::milvus",
@ -68,6 +69,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
"remote::milvus",
"remote::pgvector",
"remote::weaviate",
"remote::chromadb",
],
}
supported_providers = search_mode_support.get(search_mode, [])