change Reranker to WeightedInMemoryAggregator

This commit is contained in:
kimbwook 2025-09-11 21:40:21 +09:00
parent 60318b659d
commit 897be1376e
No known key found for this signature in database
GPG key ID: 13B032C99CBD373A
6 changed files with 22 additions and 142 deletions

View file

@ -3,15 +3,15 @@
## Overview ## Overview
The Batches API enables efficient processing of multiple requests in a single operation, The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale. cost-effective inference at scale.
The API is designed to allow use of openai client libraries for seamless integration. The API is designed to allow use of openai client libraries for seamless integration.
This API provides the following extensions: This API provides the following extensions:
- idempotent batch creation - idempotent batch creation
Note: This API is currently under active development and may undergo changes. Note: This API is currently under active development and may undergo changes.
This section contains documentation for all available providers for the **batches** API. This section contains documentation for all available providers for the **batches** API.

View file

@ -31,7 +31,7 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorDBWithIndex,
) )
from llama_stack.providers.utils.vector_io.vector_utils import Reranker from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
@ -192,7 +192,9 @@ class ChromaIndex(EmbeddingIndex):
} }
# Combine scores using the reranking utility # Combine scores using the reranking utility
combined_scores = Reranker.combine_search_results(vector_scores, keyword_scores, reranker_type, reranker_params) combined_scores = WeightedInMemoryAggregator.combine_search_results(
vector_scores, keyword_scores, reranker_type, reranker_params
)
# Efficient top-k selection because it only tracks the k best candidates it's seen so far # Efficient top-k selection because it only tracks the k best candidates it's seen so far
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1]) top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])

View file

@ -39,7 +39,6 @@ def sanitize_collection_name(name: str, weaviate_format=False) -> str:
return s return s
class WeightedInMemoryAggregator: class WeightedInMemoryAggregator:
@staticmethod @staticmethod
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]: def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:

View file

@ -49,6 +49,7 @@ dependencies = [
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server "opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
"aiosqlite>=0.21.0", # server - for metadata store "aiosqlite>=0.21.0", # server - for metadata store
"asyncpg", # for metadata store "asyncpg", # for metadata store
"pre-commit>=4.2.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View file

@ -1,124 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire chromadb module
chromadb_mock = MagicMock()
chromadb_mock.AsyncHttpClient = MagicMock
chromadb_mock.PersistentClient = MagicMock
# Apply the mock before importing ChromaIndex
with patch.dict("sys.modules", {"chromadb": chromadb_mock}):
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex
# This test is a unit test for the ChromaVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_chroma.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
CHROMA_PROVIDER = "chromadb"
@pytest.fixture
async def mock_chroma_collection() -> MagicMock:
"""Create a mock Chroma collection with common method behaviors."""
collection = MagicMock()
collection.name = "test_collection"
# Mock add operation
collection.add.return_value = None
# Mock query operation for vector search
collection.query.return_value = {
"distances": [[0.1, 0.2]],
"documents": [
[
json.dumps({"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}),
json.dumps({"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}),
]
],
}
# Mock delete operation
collection.delete.return_value = None
return collection
@pytest.fixture
async def mock_chroma_client(mock_chroma_collection):
"""Create a mock Chroma client with common method behaviors."""
client = MagicMock()
# Mock collection operations
client.get_or_create_collection.return_value = mock_chroma_collection
client.get_collection.return_value = mock_chroma_collection
client.delete_collection.return_value = None
return client
@pytest.fixture
async def chroma_index(mock_chroma_client, mock_chroma_collection):
"""Create a ChromaIndex with mocked client and collection."""
index = ChromaIndex(client=mock_chroma_client, collection=mock_chroma_collection)
yield index
# No real cleanup needed since we're using mocks
async def test_add_chunks(chroma_index, sample_chunks, sample_embeddings, mock_chroma_collection):
await chroma_index.add_chunks(sample_chunks, sample_embeddings)
# Verify data was inserted
mock_chroma_collection.add.assert_called_once()
# Verify the add call had the right number of chunks
add_call = mock_chroma_collection.add.call_args
assert len(add_call[1]["documents"]) == len(sample_chunks)
async def test_query_chunks_vector(
chroma_index, sample_chunks, sample_embeddings, embedding_dimension, mock_chroma_collection
):
# Setup: Add chunks first
await chroma_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await chroma_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
mock_chroma_collection.query.assert_called_once()
async def test_query_chunks_keyword_search(chroma_index, sample_chunks, sample_embeddings, mock_chroma_collection):
await chroma_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search
query_string = "Sentence 5"
response = await chroma_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
async def test_delete_collection(chroma_index, mock_chroma_client):
# Test collection deletion
await chroma_index.delete()
mock_chroma_client.delete_collection.assert_called_once_with(chroma_index.collection.name)

4
uv.lock generated
View file

@ -1,5 +1,5 @@
version = 1 version = 1
revision = 2 revision = 3
requires-python = ">=3.12" requires-python = ">=3.12"
resolution-markers = [ resolution-markers = [
"(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
@ -1767,6 +1767,7 @@ dependencies = [
{ name = "opentelemetry-exporter-otlp-proto-http" }, { name = "opentelemetry-exporter-otlp-proto-http" },
{ name = "opentelemetry-sdk" }, { name = "opentelemetry-sdk" },
{ name = "pillow" }, { name = "pillow" },
{ name = "pre-commit" },
{ name = "prompt-toolkit" }, { name = "prompt-toolkit" },
{ name = "pydantic" }, { name = "pydantic" },
{ name = "python-dotenv" }, { name = "python-dotenv" },
@ -1892,6 +1893,7 @@ requires-dist = [
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" }, { name = "opentelemetry-sdk", specifier = ">=1.30.0" },
{ name = "pandas", marker = "extra == 'ui'" }, { name = "pandas", marker = "extra == 'ui'" },
{ name = "pillow" }, { name = "pillow" },
{ name = "pre-commit", specifier = ">=4.2.0" },
{ name = "prompt-toolkit" }, { name = "prompt-toolkit" },
{ name = "pydantic", specifier = ">=2" }, { name = "pydantic", specifier = ">=2" },
{ name = "python-dotenv" }, { name = "python-dotenv" },