mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: implement keyword, vector and hybrid search inside vector stores for PGVector provider (#3064)
# What does this PR do? <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> The purpose of this task is to implement `openai/v1/vector_stores/{vector_store_id}/search` for PGVector provider. It involves implementing vector similarity search, keyword search and hybrid search for `PGVectorIndex`. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> Closes #3006 ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> Run unit tests: ` ./scripts/unit-tests.sh ` Run integration tests for openai vector stores: 1. Export env vars: ``` export ENABLE_PGVECTOR=true export PGVECTOR_HOST=localhost export PGVECTOR_PORT=5432 export PGVECTOR_DB=llamastack export PGVECTOR_USER=llamastack export PGVECTOR_PASSWORD=llamastack ``` 2. Create DB: ``` psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';" psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;" psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;" ``` 3. Install sentence-transformers: ` uv pip install sentence-transformers ` 4. Run: ``` uv run --group test pytest -s -v --stack-config="inference=inline::sentence-transformers,vector_io=remote::pgvector" --embedding-model sentence-transformers/all-MiniLM-L6-v2 tests/integration/vector_io/test_openai_vector_stores.py ``` Inspect PGVector vector stores (optional): ``` psql llamastack psql (14.18 (Homebrew)) Type "help" for help. llamastack=# \z Access privileges Schema | Name | Type | Access privileges | Column privileges | Policies --------+------------------------------------------------------+-------+-------------------+-------------------+---------- public | llamastack_kvstore | table | | | public | metadata_store | table | | | public | vector_store_pgvector_main | table | | | public | vector_store_vs_1dfbc061_1f4d_4497_9165_ecba2622ba3a | table | | | public | vector_store_vs_2085a9fb_1822_4e42_a277_c6a685843fa7 | table | | | public | vector_store_vs_2b3dae46_38be_462a_afd6_37ee5fe661b1 | table | | | public | vector_store_vs_2f438de6_f606_4561_9d50_ef9160eb9060 | table | | | public | vector_store_vs_3eeca564_2580_4c68_bfea_83dc57e31214 | table | | | public | vector_store_vs_53942163_05f3_40e0_83c0_0997c64613da | table | | | public | vector_store_vs_545bac75_8950_4ff1_b084_e221192d4709 | table | | | public | vector_store_vs_688a37d8_35b2_4298_a035_bfedf5b21f86 | table | | | public | vector_store_vs_70624d9a_f6ac_4c42_b8ab_0649473c6600 | table | | | public | vector_store_vs_73fc1dd2_e942_4972_afb1_1e177b591ac2 | table | | | public | vector_store_vs_9d464949_d51f_49db_9f87_e033b8b84ac9 | table | | | public | vector_store_vs_a1e4d724_5162_4d6d_a6c0_bdafaf6b76ec | table | | | public | vector_store_vs_a328fb1b_1a21_480f_9624_ffaa60fb6672 | table | | | public | vector_store_vs_a8981bf0_2e66_4445_a267_a8fff442db53 | table | | | public | vector_store_vs_ccd4b6a4_1efd_4984_ad03_e7ff8eadb296 | table | | | public | vector_store_vs_cd6420a4_a1fc_4cec_948c_1413a26281c9 | table | | | public | vector_store_vs_cd709284_e5cf_4a88_aba5_dc76a35364bd | table | | | public | vector_store_vs_d7a4548e_fbc1_44d7_b2ec_b664417f2a46 | table | | | public | vector_store_vs_e7f73231_414c_4523_886c_d1174eee836e | table | | | public | vector_store_vs_ffd53588_819f_47e8_bb9d_954af6f7833d | table | | | (23 rows) llamastack=# ``` Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>
This commit is contained in:
parent
e96e3c4da4
commit
3130ca0a78
11 changed files with 1014 additions and 29 deletions
|
@ -404,6 +404,60 @@ That means you'll get fast and efficient vector retrieval.
|
|||
- Easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
|
||||
There are three implementations of search for PGVectoIndex available:
|
||||
|
||||
1. Vector Search:
|
||||
- How it works:
|
||||
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
|
||||
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
|
||||
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
|
||||
|
||||
-Characteristics:
|
||||
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
|
||||
- Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions)
|
||||
- Best for: Finding conceptually related content, handling synonyms, cross-language search
|
||||
|
||||
2. Keyword Search
|
||||
- How it works:
|
||||
- Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank
|
||||
- Converts text to searchable tokens using to_tsvector('english', text). Default language is English.
|
||||
- Eg. SQL query: SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score
|
||||
|
||||
- Characteristics:
|
||||
- Lexical matching - finds exact keyword matches and variations
|
||||
- Uses GIN (Generalized Inverted Index) for fast text search performance
|
||||
- Scoring: Uses PostgreSQL's ts_rank function for relevance scoring
|
||||
- Best for: Exact term matching, proper names, technical terms, Boolean-style queries
|
||||
|
||||
3. Hybrid Search
|
||||
- How it works:
|
||||
- Combines both vector and keyword search results
|
||||
- Runs both searches independently, then merges results using configurable reranking
|
||||
|
||||
- Two reranking strategies available:
|
||||
- Reciprocal Rank Fusion (RRF) - (default: 60.0)
|
||||
- Weighted Average - (default: 0.5)
|
||||
|
||||
- Characteristics:
|
||||
- Best of both worlds: semantic understanding + exact matching
|
||||
- Documents appearing in both searches get boosted scores
|
||||
- Configurable balance between semantic and lexical matching
|
||||
- Best for: General-purpose search where you want both precision and recall
|
||||
|
||||
4. Database Schema
|
||||
The PGVector implementation stores data optimized for all three search types:
|
||||
CREATE TABLE vector_store_xxx (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB, -- Original document
|
||||
embedding vector(dimension), -- For vector search
|
||||
content_text TEXT, -- Raw text content
|
||||
tokenized_content TSVECTOR -- For keyword search
|
||||
);
|
||||
|
||||
-- Indexes for performance
|
||||
CREATE INDEX content_gin_idx ON table USING GIN(tokenized_content); -- Keyword search
|
||||
-- Vector index created automatically by pgvector
|
||||
|
||||
## Usage
|
||||
|
||||
To use PGVector in your Llama Stack project, follow these steps:
|
||||
|
@ -412,6 +466,25 @@ To use PGVector in your Llama Stack project, follow these steps:
|
|||
2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## This is an example how you can set up your environment for using PGVector
|
||||
|
||||
1. Export env vars:
|
||||
```bash
|
||||
export ENABLE_PGVECTOR=true
|
||||
export PGVECTOR_HOST=localhost
|
||||
export PGVECTOR_PORT=5432
|
||||
export PGVECTOR_DB=llamastack
|
||||
export PGVECTOR_USER=llamastack
|
||||
export PGVECTOR_PASSWORD=llamastack
|
||||
```
|
||||
|
||||
2. Create DB:
|
||||
```bash
|
||||
psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';"
|
||||
psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;"
|
||||
psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;"
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
You can install PGVector using docker:
|
||||
|
@ -449,6 +522,7 @@ Weaviate supports:
|
|||
- Metadata filtering
|
||||
- Multi-modal retrieval
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
To use Weaviate in your Llama Stack project, follow these steps:
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import heapq
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
|
@ -23,6 +24,9 @@ from llama_stack.apis.vector_io import (
|
|||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
@ -31,6 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
|
||||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
||||
|
@ -72,25 +77,63 @@ def load_models(cur, cls):
|
|||
|
||||
|
||||
class PGVectorIndex(EmbeddingIndex):
|
||||
def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
|
||||
self.conn = conn
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Sanitize the table name by replacing hyphens with underscores
|
||||
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
|
||||
# when created with patterns like "test-vector-db-{uuid4()}"
|
||||
sanitized_identifier = vector_db.identifier.replace("-", "_")
|
||||
self.table_name = f"vector_store_{sanitized_identifier}"
|
||||
self.kvstore = kvstore
|
||||
# reference: https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
|
||||
PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION: dict[str, str] = {
|
||||
"L2": "<->",
|
||||
"L1": "<+>",
|
||||
"COSINE": "<=>",
|
||||
"INNER_PRODUCT": "<#>",
|
||||
"HAMMING": "<~>",
|
||||
"JACCARD": "<%>",
|
||||
}
|
||||
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB,
|
||||
embedding vector({dimension})
|
||||
def __init__(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
dimension: int,
|
||||
conn: psycopg2.extensions.connection,
|
||||
kvstore: KVStore | None = None,
|
||||
distance_metric: str = "COSINE",
|
||||
):
|
||||
self.vector_db = vector_db
|
||||
self.dimension = dimension
|
||||
self.conn = conn
|
||||
self.kvstore = kvstore
|
||||
self.check_distance_metric_availability(distance_metric)
|
||||
self.distance_metric = distance_metric
|
||||
self.table_name = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Sanitize the table name by replacing hyphens with underscores
|
||||
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
|
||||
# when created with patterns like "test-vector-db-{uuid4()}"
|
||||
sanitized_identifier = sanitize_collection_name(self.vector_db.identifier)
|
||||
self.table_name = f"vs_{sanitized_identifier}"
|
||||
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB,
|
||||
embedding vector({self.dimension}),
|
||||
content_text TEXT,
|
||||
tokenized_content TSVECTOR
|
||||
)
|
||||
"""
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Create GIN index for full-text search performance
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE INDEX IF NOT EXISTS {self.table_name}_content_gin_idx
|
||||
ON {self.table_name} USING GIN(tokenized_content)
|
||||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
|
||||
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
|
@ -99,29 +142,49 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
values = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
content_text = interleaved_content_as_str(chunk.content)
|
||||
values.append(
|
||||
(
|
||||
f"{chunk.chunk_id}",
|
||||
Json(chunk.model_dump()),
|
||||
embeddings[i].tolist(),
|
||||
content_text,
|
||||
content_text, # Pass content_text twice - once for content_text column, once for to_tsvector function. Eg. to_tsvector(content_text) = tokenized_content
|
||||
)
|
||||
)
|
||||
|
||||
query = sql.SQL(
|
||||
f"""
|
||||
INSERT INTO {self.table_name} (id, document, embedding)
|
||||
INSERT INTO {self.table_name} (id, document, embedding, content_text, tokenized_content)
|
||||
VALUES %s
|
||||
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
embedding = EXCLUDED.embedding,
|
||||
document = EXCLUDED.document,
|
||||
content_text = EXCLUDED.content_text,
|
||||
tokenized_content = EXCLUDED.tokenized_content
|
||||
"""
|
||||
)
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
|
||||
execute_values(cur, query, values, template="(%s, %s, %s::vector, %s, to_tsvector('english', %s))")
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs vector similarity search using PostgreSQL's search function. Default distance metric is COSINE.
|
||||
|
||||
Args:
|
||||
embedding: The query embedding vector
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
pgvector_search_function = self.get_pgvector_search_function()
|
||||
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT document, embedding <-> %s::vector AS distance
|
||||
SELECT document, embedding {pgvector_search_function} %s::vector AS distance
|
||||
FROM {self.table_name}
|
||||
ORDER BY distance
|
||||
LIMIT %s
|
||||
|
@ -147,7 +210,40 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in PGVector")
|
||||
"""
|
||||
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
|
||||
|
||||
Args:
|
||||
query_string: The text query for keyword search
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Use plainto_tsquery to handle user input safely and ts_rank for relevance scoring
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE tokenized_content @@ plainto_tsquery('english', %s)
|
||||
ORDER BY score DESC
|
||||
LIMIT %s
|
||||
""",
|
||||
(query_string, query_string, k),
|
||||
)
|
||||
results = cur.fetchall()
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc, score in results:
|
||||
if score < score_threshold:
|
||||
continue
|
||||
chunks.append(Chunk(**doc))
|
||||
scores.append(float(score))
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
|
@ -158,7 +254,59 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in PGVector")
|
||||
"""
|
||||
Hybrid search combining vector similarity and keyword search using configurable reranking.
|
||||
|
||||
Args:
|
||||
embedding: The query embedding vector
|
||||
query_string: The text query for keyword search
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
reranker_type: Type of reranker to use ("rrf" or "weighted")
|
||||
reranker_params: Parameters for the reranker
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
if reranker_params is None:
|
||||
reranker_params = {}
|
||||
|
||||
# Get results from both search methods
|
||||
vector_response = await self.query_vector(embedding, k, score_threshold)
|
||||
keyword_response = await self.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
# Convert responses to score dictionaries using chunk_id
|
||||
vector_scores = {
|
||||
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
|
||||
}
|
||||
keyword_scores = {
|
||||
chunk.chunk_id: score
|
||||
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
|
||||
}
|
||||
|
||||
# Combine scores using the reranking utility
|
||||
combined_scores = WeightedInMemoryAggregator.combine_search_results(
|
||||
vector_scores, keyword_scores, reranker_type, reranker_params
|
||||
)
|
||||
|
||||
# Efficient top-k selection because it only tracks the k best candidates it's seen so far
|
||||
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
|
||||
|
||||
# Filter by score threshold
|
||||
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
|
||||
|
||||
# Create a map of chunk_id to chunk for both responses
|
||||
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
|
||||
|
||||
# Use the map to look up chunks by their IDs
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc_id, score in filtered_items:
|
||||
if doc_id in chunk_map:
|
||||
chunks.append(chunk_map[doc_id])
|
||||
scores.append(score)
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self):
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
|
@ -170,6 +318,25 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
|
||||
|
||||
def get_pgvector_search_function(self) -> str:
|
||||
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
|
||||
|
||||
def check_distance_metric_availability(self, distance_metric: str) -> None:
|
||||
"""Check if the distance metric is supported by PGVector.
|
||||
|
||||
Args:
|
||||
distance_metric: The distance metric to check
|
||||
|
||||
Raises:
|
||||
ValueError: If the distance metric is not supported
|
||||
"""
|
||||
if distance_metric not in self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION:
|
||||
supported_metrics = list(self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION.keys())
|
||||
raise ValueError(
|
||||
f"Distance metric '{distance_metric}' is not supported by PGVector. "
|
||||
f"Supported metrics are: {', '.join(supported_metrics)}"
|
||||
)
|
||||
|
||||
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
|
@ -185,8 +352,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
self.files_api = files_api
|
||||
self.kvstore: KVStore | None = None
|
||||
self.vector_db_store = None
|
||||
self.openai_vector_store: dict[str, dict[str, Any]] = {}
|
||||
self.metadatadata_collection_name = "openai_vector_stores_metadata"
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
|
@ -233,9 +400,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
|
||||
|
||||
# Create and cache the PGVector index table for the vector DB
|
||||
pgvector_index = PGVectorIndex(
|
||||
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
|
||||
)
|
||||
await pgvector_index.initialize()
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
|
||||
index=pgvector_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
@ -272,8 +443,15 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
|
||||
await index.initialize()
|
||||
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
|
|
|
@ -37,3 +37,122 @@ def sanitize_collection_name(name: str, weaviate_format=False) -> str:
|
|||
else:
|
||||
s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name))
|
||||
return s
|
||||
|
||||
|
||||
class WeightedInMemoryAggregator:
|
||||
@staticmethod
|
||||
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
|
||||
"""
|
||||
Normalize scores to 0-1 range using min-max normalization.
|
||||
|
||||
Args:
|
||||
scores: dictionary of scores with document IDs as keys and scores as values
|
||||
|
||||
Returns:
|
||||
Normalized scores with document IDs as keys and normalized scores as values
|
||||
"""
|
||||
if not scores:
|
||||
return {}
|
||||
min_score, max_score = min(scores.values()), max(scores.values())
|
||||
score_range = max_score - min_score
|
||||
if score_range > 0:
|
||||
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
|
||||
return dict.fromkeys(scores, 1.0)
|
||||
|
||||
@staticmethod
|
||||
def weighted_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
alpha: float = 0.5,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Rerank via weighted average of scores.
|
||||
|
||||
Args:
|
||||
vector_scores: scores from vector search
|
||||
keyword_scores: scores from keyword search
|
||||
alpha: weight factor between 0 and 1 (default: 0.5)
|
||||
0 = keyword only, 1 = vector only, 0.5 = equal weight
|
||||
|
||||
Returns:
|
||||
All unique document IDs with weighted combined scores
|
||||
"""
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
normalized_vector_scores = WeightedInMemoryAggregator._normalize_scores(vector_scores)
|
||||
normalized_keyword_scores = WeightedInMemoryAggregator._normalize_scores(keyword_scores)
|
||||
|
||||
# Weighted formula: score = (1-alpha) * keyword_score + alpha * vector_score
|
||||
# alpha=0 means keyword only, alpha=1 means vector only
|
||||
return {
|
||||
doc_id: ((1 - alpha) * normalized_keyword_scores.get(doc_id, 0.0))
|
||||
+ (alpha * normalized_vector_scores.get(doc_id, 0.0))
|
||||
for doc_id in all_ids
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def rrf_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
impact_factor: float = 60.0,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Rerank via Reciprocal Rank Fusion.
|
||||
|
||||
Args:
|
||||
vector_scores: scores from vector search
|
||||
keyword_scores: scores from keyword search
|
||||
impact_factor: impact factor for RRF (default: 60.0)
|
||||
|
||||
Returns:
|
||||
All unique document IDs with RRF combined scores
|
||||
"""
|
||||
|
||||
# Convert scores to ranks
|
||||
vector_ranks = {
|
||||
doc_id: i + 1
|
||||
for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
keyword_ranks = {
|
||||
doc_id: i + 1
|
||||
for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
rrf_scores = {}
|
||||
for doc_id in all_ids:
|
||||
vector_rank = vector_ranks.get(doc_id, float("inf"))
|
||||
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
|
||||
|
||||
# RRF formula: score = 1/(k + r) where k is impact_factor (default: 60.0) and r is the rank
|
||||
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
|
||||
return rrf_scores
|
||||
|
||||
@staticmethod
|
||||
def combine_search_results(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
reranker_type: str = "rrf",
|
||||
reranker_params: dict[str, float] | None = None,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Combine vector and keyword search results using specified reranking strategy.
|
||||
|
||||
Args:
|
||||
vector_scores: scores from vector search
|
||||
keyword_scores: scores from keyword search
|
||||
reranker_type: type of reranker to use (default: RERANKER_TYPE_RRF)
|
||||
reranker_params: parameters for the reranker
|
||||
|
||||
Returns:
|
||||
All unique document IDs with combined scores
|
||||
"""
|
||||
if reranker_params is None:
|
||||
reranker_params = {}
|
||||
|
||||
if reranker_type == "weighted":
|
||||
alpha = reranker_params.get("alpha", 0.5)
|
||||
return WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha)
|
||||
else:
|
||||
# Default to RRF for None, RRF, or any unknown types
|
||||
impact_factor = reranker_params.get("impact_factor", 60.0)
|
||||
return WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue