mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
change Reranker to WeightedInMemoryAggregator
This commit is contained in:
parent
60318b659d
commit
897be1376e
6 changed files with 22 additions and 142 deletions
|
@ -3,15 +3,15 @@
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
The Batches API enables efficient processing of multiple requests in a single operation,
|
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||||
particularly useful for processing large datasets, batch evaluation workflows, and
|
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||||
cost-effective inference at scale.
|
cost-effective inference at scale.
|
||||||
|
|
||||||
The API is designed to allow use of openai client libraries for seamless integration.
|
The API is designed to allow use of openai client libraries for seamless integration.
|
||||||
|
|
||||||
This API provides the following extensions:
|
This API provides the following extensions:
|
||||||
- idempotent batch creation
|
- idempotent batch creation
|
||||||
|
|
||||||
Note: This API is currently under active development and may undergo changes.
|
Note: This API is currently under active development and may undergo changes.
|
||||||
|
|
||||||
This section contains documentation for all available providers for the **batches** API.
|
This section contains documentation for all available providers for the **batches** API.
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
VectorDBWithIndex,
|
VectorDBWithIndex,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.vector_io.vector_utils import Reranker
|
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
|
||||||
|
|
||||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||||
|
|
||||||
|
@ -192,7 +192,9 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
}
|
}
|
||||||
|
|
||||||
# Combine scores using the reranking utility
|
# Combine scores using the reranking utility
|
||||||
combined_scores = Reranker.combine_search_results(vector_scores, keyword_scores, reranker_type, reranker_params)
|
combined_scores = WeightedInMemoryAggregator.combine_search_results(
|
||||||
|
vector_scores, keyword_scores, reranker_type, reranker_params
|
||||||
|
)
|
||||||
|
|
||||||
# Efficient top-k selection because it only tracks the k best candidates it's seen so far
|
# Efficient top-k selection because it only tracks the k best candidates it's seen so far
|
||||||
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
|
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
|
||||||
|
|
|
@ -39,7 +39,6 @@ def sanitize_collection_name(name: str, weaviate_format=False) -> str:
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WeightedInMemoryAggregator:
|
class WeightedInMemoryAggregator:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
|
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
|
||||||
|
|
|
@ -25,14 +25,14 @@ classifiers = [
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
"fastapi>=0.115.0,<1.0", # server
|
"fastapi>=0.115.0,<1.0", # server
|
||||||
"fire", # for MCP in LLS client
|
"fire", # for MCP in LLS client
|
||||||
"httpx",
|
"httpx",
|
||||||
"huggingface-hub>=0.34.0,<1.0",
|
"huggingface-hub>=0.34.0,<1.0",
|
||||||
"jinja2>=3.1.6",
|
"jinja2>=3.1.6",
|
||||||
"jsonschema",
|
"jsonschema",
|
||||||
"llama-stack-client>=0.2.21",
|
"llama-stack-client>=0.2.21",
|
||||||
"openai>=1.100.0", # for expires_after support
|
"openai>=1.100.0", # for expires_after support
|
||||||
"prompt-toolkit",
|
"prompt-toolkit",
|
||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
"python-jose[cryptography]",
|
"python-jose[cryptography]",
|
||||||
|
@ -43,12 +43,13 @@ dependencies = [
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
"pillow",
|
"pillow",
|
||||||
"h11>=0.16.0",
|
"h11>=0.16.0",
|
||||||
"python-multipart>=0.0.20", # For fastapi Form
|
"python-multipart>=0.0.20", # For fastapi Form
|
||||||
"uvicorn>=0.34.0", # server
|
"uvicorn>=0.34.0", # server
|
||||||
"opentelemetry-sdk>=1.30.0", # server
|
"opentelemetry-sdk>=1.30.0", # server
|
||||||
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
|
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
|
||||||
"aiosqlite>=0.21.0", # server - for metadata store
|
"aiosqlite>=0.21.0", # server - for metadata store
|
||||||
"asyncpg", # for metadata store
|
"asyncpg", # for metadata store
|
||||||
|
"pre-commit>=4.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|
|
@ -1,124 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import json
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
|
||||||
|
|
||||||
# Mock the entire chromadb module
|
|
||||||
chromadb_mock = MagicMock()
|
|
||||||
chromadb_mock.AsyncHttpClient = MagicMock
|
|
||||||
chromadb_mock.PersistentClient = MagicMock
|
|
||||||
|
|
||||||
# Apply the mock before importing ChromaIndex
|
|
||||||
with patch.dict("sys.modules", {"chromadb": chromadb_mock}):
|
|
||||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex
|
|
||||||
|
|
||||||
# This test is a unit test for the ChromaVectorIOAdapter class. This should only contain
|
|
||||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
|
||||||
# tests/integration/vector_io/
|
|
||||||
#
|
|
||||||
# How to run this test:
|
|
||||||
#
|
|
||||||
# pytest tests/unit/providers/vector_io/test_chroma.py \
|
|
||||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
|
||||||
|
|
||||||
CHROMA_PROVIDER = "chromadb"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def mock_chroma_collection() -> MagicMock:
|
|
||||||
"""Create a mock Chroma collection with common method behaviors."""
|
|
||||||
collection = MagicMock()
|
|
||||||
collection.name = "test_collection"
|
|
||||||
|
|
||||||
# Mock add operation
|
|
||||||
collection.add.return_value = None
|
|
||||||
|
|
||||||
# Mock query operation for vector search
|
|
||||||
collection.query.return_value = {
|
|
||||||
"distances": [[0.1, 0.2]],
|
|
||||||
"documents": [
|
|
||||||
[
|
|
||||||
json.dumps({"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}),
|
|
||||||
json.dumps({"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}),
|
|
||||||
]
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Mock delete operation
|
|
||||||
collection.delete.return_value = None
|
|
||||||
|
|
||||||
return collection
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def mock_chroma_client(mock_chroma_collection):
|
|
||||||
"""Create a mock Chroma client with common method behaviors."""
|
|
||||||
client = MagicMock()
|
|
||||||
|
|
||||||
# Mock collection operations
|
|
||||||
client.get_or_create_collection.return_value = mock_chroma_collection
|
|
||||||
client.get_collection.return_value = mock_chroma_collection
|
|
||||||
client.delete_collection.return_value = None
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def chroma_index(mock_chroma_client, mock_chroma_collection):
|
|
||||||
"""Create a ChromaIndex with mocked client and collection."""
|
|
||||||
index = ChromaIndex(client=mock_chroma_client, collection=mock_chroma_collection)
|
|
||||||
yield index
|
|
||||||
# No real cleanup needed since we're using mocks
|
|
||||||
|
|
||||||
|
|
||||||
async def test_add_chunks(chroma_index, sample_chunks, sample_embeddings, mock_chroma_collection):
|
|
||||||
await chroma_index.add_chunks(sample_chunks, sample_embeddings)
|
|
||||||
|
|
||||||
# Verify data was inserted
|
|
||||||
mock_chroma_collection.add.assert_called_once()
|
|
||||||
|
|
||||||
# Verify the add call had the right number of chunks
|
|
||||||
add_call = mock_chroma_collection.add.call_args
|
|
||||||
assert len(add_call[1]["documents"]) == len(sample_chunks)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_query_chunks_vector(
|
|
||||||
chroma_index, sample_chunks, sample_embeddings, embedding_dimension, mock_chroma_collection
|
|
||||||
):
|
|
||||||
# Setup: Add chunks first
|
|
||||||
await chroma_index.add_chunks(sample_chunks, sample_embeddings)
|
|
||||||
|
|
||||||
# Test vector search
|
|
||||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
|
||||||
response = await chroma_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
|
||||||
|
|
||||||
assert isinstance(response, QueryChunksResponse)
|
|
||||||
assert len(response.chunks) == 2
|
|
||||||
mock_chroma_collection.query.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_query_chunks_keyword_search(chroma_index, sample_chunks, sample_embeddings, mock_chroma_collection):
|
|
||||||
await chroma_index.add_chunks(sample_chunks, sample_embeddings)
|
|
||||||
|
|
||||||
# Test keyword search
|
|
||||||
query_string = "Sentence 5"
|
|
||||||
response = await chroma_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
|
||||||
|
|
||||||
assert isinstance(response, QueryChunksResponse)
|
|
||||||
assert len(response.chunks) == 2
|
|
||||||
|
|
||||||
|
|
||||||
async def test_delete_collection(chroma_index, mock_chroma_client):
|
|
||||||
# Test collection deletion
|
|
||||||
await chroma_index.delete()
|
|
||||||
|
|
||||||
mock_chroma_client.delete_collection.assert_called_once_with(chroma_index.collection.name)
|
|
4
uv.lock
generated
4
uv.lock
generated
|
@ -1,5 +1,5 @@
|
||||||
version = 1
|
version = 1
|
||||||
revision = 2
|
revision = 3
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
resolution-markers = [
|
resolution-markers = [
|
||||||
"(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
"(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||||
|
@ -1767,6 +1767,7 @@ dependencies = [
|
||||||
{ name = "opentelemetry-exporter-otlp-proto-http" },
|
{ name = "opentelemetry-exporter-otlp-proto-http" },
|
||||||
{ name = "opentelemetry-sdk" },
|
{ name = "opentelemetry-sdk" },
|
||||||
{ name = "pillow" },
|
{ name = "pillow" },
|
||||||
|
{ name = "pre-commit" },
|
||||||
{ name = "prompt-toolkit" },
|
{ name = "prompt-toolkit" },
|
||||||
{ name = "pydantic" },
|
{ name = "pydantic" },
|
||||||
{ name = "python-dotenv" },
|
{ name = "python-dotenv" },
|
||||||
|
@ -1892,6 +1893,7 @@ requires-dist = [
|
||||||
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },
|
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },
|
||||||
{ name = "pandas", marker = "extra == 'ui'" },
|
{ name = "pandas", marker = "extra == 'ui'" },
|
||||||
{ name = "pillow" },
|
{ name = "pillow" },
|
||||||
|
{ name = "pre-commit", specifier = ">=4.2.0" },
|
||||||
{ name = "prompt-toolkit" },
|
{ name = "prompt-toolkit" },
|
||||||
{ name = "pydantic", specifier = ">=2" },
|
{ name = "pydantic", specifier = ">=2" },
|
||||||
{ name = "python-dotenv" },
|
{ name = "python-dotenv" },
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue