feat: implement keyword, vector and hybrid search inside vector stores for PGVector provider

This commit is contained in:
r3v5 2025-08-07 16:57:19 +01:00
parent 46ff302d87
commit a29af745e4
No known key found for this signature in database
GPG key ID: 7758B9F272DE67D9
11 changed files with 1332 additions and 83 deletions

View file

@ -12,6 +12,60 @@ That means you'll get fast and efficient vector retrieval.
- Easy to use - Easy to use
- Fully integrated with Llama Stack - Fully integrated with Llama Stack
Three implementations of search for PGVectoIndex:
1. Vector Search:
- How it works:
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
-Characteristics:
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
- Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions)
- Best for: Finding conceptually related content, handling synonyms, cross-language search
2. Keyword Search
- How it works:
- Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank
- Converts text to searchable tokens using to_tsvector('english', text)
- Eg. SQL query: SELECT document, ts_rank(content_tsvector, plainto_tsquery('english', %s)) AS score
- Characteristics:
- Lexical matching - finds exact keyword matches and variations
- Uses GIN (Generalized Inverted Index) for fast text search performance
- Scoring: Uses PostgreSQL's ts_rank function for relevance scoring
- Best for: Exact term matching, proper names, technical terms, Boolean-style queries
3. Hybrid Search
- How it works:
- Combines both vector and keyword search results
- Runs both searches independently, then merges results using configurable reranking
- Two reranking strategies available:
- Reciprocal Rank Fusion (RRF) - (default: 60.0)
- Weighted Average - (default: 0.5)
- Characteristics:
- Best of both worlds: semantic understanding + exact matching
- Documents appearing in both searches get boosted scores
- Configurable balance between semantic and lexical matching
- Best for: General-purpose search where you want both precision and recall
4. Database Schema
The PGVector implementation stores data optimized for all three search types:
CREATE TABLE vector_store_xxx (
id TEXT PRIMARY KEY,
document JSONB, -- Original document
embedding vector(dimension), -- For vector search
content_text TEXT, -- Raw text content
content_tsvector TSVECTOR -- For keyword search
);
-- Indexes for performance
CREATE INDEX content_gin_idx ON table USING GIN(content_tsvector); -- Keyword search
-- Vector index created automatically by pgvector
## Usage ## Usage
To use PGVector in your Llama Stack project, follow these steps: To use PGVector in your Llama Stack project, follow these steps:

View file

@ -35,6 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorDBWithIndex,
) )
from llama_stack.providers.utils.vector_io.vector_utils import Reranker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -66,59 +67,6 @@ def _create_sqlite_connection(db_path):
return connection return connection
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
"""Normalize scores to [0,1] range using min-max normalization."""
if not scores:
return {}
min_score = min(scores.values())
max_score = max(scores.values())
score_range = max_score - min_score
if score_range > 0:
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
return dict.fromkeys(scores, 1.0)
def _weighted_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
alpha: float = 0.5,
) -> dict[str, float]:
"""ReRanker that uses weighted average of scores."""
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
normalized_vector_scores = _normalize_scores(vector_scores)
normalized_keyword_scores = _normalize_scores(keyword_scores)
return {
doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
+ ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
for doc_id in all_ids
}
def _rrf_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
impact_factor: float = 60.0,
) -> dict[str, float]:
"""ReRanker that uses Reciprocal Rank Fusion."""
# Convert scores to ranks
vector_ranks = {
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
}
keyword_ranks = {
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
}
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
rrf_scores = {}
for doc_id in all_ids:
vector_rank = vector_ranks.get(doc_id, float("inf"))
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
# RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
return rrf_scores
def _make_sql_identifier(name: str) -> str: def _make_sql_identifier(name: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", name) return re.sub(r"[^a-zA-Z0-9_]", "_", name)
@ -401,11 +349,11 @@ class SQLiteVecIndex(EmbeddingIndex):
# Combine scores using the specified reranker # Combine scores using the specified reranker
if reranker_type == RERANKER_TYPE_WEIGHTED: if reranker_type == RERANKER_TYPE_WEIGHTED:
alpha = reranker_params.get("alpha", 0.5) alpha = reranker_params.get("alpha", 0.5)
combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha) combined_scores = Reranker.weighted_rerank(vector_scores, keyword_scores, alpha)
else: else:
# Default to RRF for None, RRF, or any unknown types # Default to RRF for None, RRF, or any unknown types
impact_factor = reranker_params.get("impact_factor", 60.0) impact_factor = reranker_params.get("impact_factor", 60.0)
combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor) combined_scores = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor)
# Sort by combined score and get top k results # Sort by combined score and get top k results
sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True) sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)

View file

@ -404,6 +404,60 @@ That means you'll get fast and efficient vector retrieval.
- Easy to use - Easy to use
- Fully integrated with Llama Stack - Fully integrated with Llama Stack
Three implementations of search for PGVectoIndex:
1. Vector Search:
- How it works:
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
-Characteristics:
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
- Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions)
- Best for: Finding conceptually related content, handling synonyms, cross-language search
2. Keyword Search
- How it works:
- Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank
- Converts text to searchable tokens using to_tsvector('english', text)
- Eg. SQL query: SELECT document, ts_rank(content_tsvector, plainto_tsquery('english', %s)) AS score
- Characteristics:
- Lexical matching - finds exact keyword matches and variations
- Uses GIN (Generalized Inverted Index) for fast text search performance
- Scoring: Uses PostgreSQL's ts_rank function for relevance scoring
- Best for: Exact term matching, proper names, technical terms, Boolean-style queries
3. Hybrid Search
- How it works:
- Combines both vector and keyword search results
- Runs both searches independently, then merges results using configurable reranking
- Two reranking strategies available:
- Reciprocal Rank Fusion (RRF) - (default: 60.0)
- Weighted Average - (default: 0.5)
- Characteristics:
- Best of both worlds: semantic understanding + exact matching
- Documents appearing in both searches get boosted scores
- Configurable balance between semantic and lexical matching
- Best for: General-purpose search where you want both precision and recall
4. Database Schema
The PGVector implementation stores data optimized for all three search types:
CREATE TABLE vector_store_xxx (
id TEXT PRIMARY KEY,
document JSONB, -- Original document
embedding vector(dimension), -- For vector search
content_text TEXT, -- Raw text content
content_tsvector TSVECTOR -- For keyword search
);
-- Indexes for performance
CREATE INDEX content_gin_idx ON table USING GIN(content_tsvector); -- Keyword search
-- Vector index created automatically by pgvector
## Usage ## Usage
To use PGVector in your Llama Stack project, follow these steps: To use PGVector in your Llama Stack project, follow these steps:
@ -449,6 +503,7 @@ Weaviate supports:
- Metadata filtering - Metadata filtering
- Multi-modal retrieval - Multi-modal retrieval
## Usage ## Usage
To use Weaviate in your Llama Stack project, follow these steps: To use Weaviate in your Llama Stack project, follow these steps:

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import heapq
import logging import logging
from typing import Any from typing import Any
@ -23,6 +24,9 @@ from llama_stack.apis.vector_io import (
VectorIO, VectorIO,
) )
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@ -31,6 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorDBWithIndex,
) )
from llama_stack.providers.utils.vector_io.vector_utils import Reranker
from .config import PGVectorVectorIOConfig from .config import PGVectorVectorIOConfig
@ -72,25 +77,63 @@ def load_models(cur, cls):
class PGVectorIndex(EmbeddingIndex): class PGVectorIndex(EmbeddingIndex):
def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None): # reference: https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
self.conn = conn PGVECTOR_DISTANCE_METRIC_TO_SEARCH_OPERATOR: dict[str, str] = {
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: "L2": "<->", # Euclidean distance
# Sanitize the table name by replacing hyphens with underscores "L1": "<+>", # Manhattan distance
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens "COSINE": "<=>", # Cosine distance
# when created with patterns like "test-vector-db-{uuid4()}" "INNER_PRODUCT": "<#>", # Inner product distance
sanitized_identifier = vector_db.identifier.replace("-", "_") "HAMMING": "<~>", # Hamming distance
self.table_name = f"vector_store_{sanitized_identifier}" "JACCARD": "<%>", # Jaccard distance
self.kvstore = kvstore }
cur.execute( def __init__(
f""" self,
CREATE TABLE IF NOT EXISTS {self.table_name} ( vector_db: VectorDB,
id TEXT PRIMARY KEY, dimension: int,
document JSONB, conn: psycopg2.extensions.connection,
embedding vector({dimension}) kvstore: KVStore | None = None,
distance_metric: str = "COSINE",
):
self.conn = conn
self.check_distance_metric_availability(distance_metric)
self.distance_metric = distance_metric
try:
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
# Sanitize the table name by replacing hyphens with underscores
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
# when created with patterns like "test-vector-db-{uuid4()}"
sanitized_identifier = vector_db.identifier.replace("-", "_")
self.table_name = f"vector_store_{sanitized_identifier}"
self.kvstore = kvstore
cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
id TEXT PRIMARY KEY,
document JSONB,
embedding vector({dimension}),
content_text TEXT,
content_tsvector TSVECTOR
)
"""
) )
"""
) # Create GIN index for full-text search performance
cur.execute(
f"""
CREATE INDEX IF NOT EXISTS {self.table_name}_content_gin_idx
ON {self.table_name} USING GIN(content_tsvector)
"""
)
except Exception as e:
log.exception(f"Error creating PGVectorIndex for vector_db: {vector_db.identifier}")
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {vector_db.identifier}") from e
async def initialize(self) -> None:
# PGVectorIndex does not require explicit initialization
pass
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), ( assert len(chunks) == len(embeddings), (
@ -99,29 +142,49 @@ class PGVectorIndex(EmbeddingIndex):
values = [] values = []
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
content_text = interleaved_content_as_str(chunk.content)
values.append( values.append(
( (
f"{chunk.chunk_id}", f"{chunk.chunk_id}",
Json(chunk.model_dump()), Json(chunk.model_dump()),
embeddings[i].tolist(), embeddings[i].tolist(),
content_text,
content_text, # Pass content_text twice - once for content_text column, once for to_tsvector function
) )
) )
query = sql.SQL( query = sql.SQL(
f""" f"""
INSERT INTO {self.table_name} (id, document, embedding) INSERT INTO {self.table_name} (id, document, embedding, content_text, content_tsvector)
VALUES %s VALUES %s
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document ON CONFLICT (id) DO UPDATE SET
embedding = EXCLUDED.embedding,
document = EXCLUDED.document,
content_text = EXCLUDED.content_text,
content_tsvector = EXCLUDED.content_tsvector
""" """
) )
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
execute_values(cur, query, values, template="(%s, %s, %s::vector)") execute_values(cur, query, values, template="(%s, %s, %s::vector, %s, to_tsvector('english', %s))")
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
"""
Performs vector similarity search using PostgreSQL's operators. Default distance metric is COSINE.
Args:
embedding: The query embedding vector
k: Number of results to return
score_threshold: Minimum similarity score threshold
Returns:
QueryChunksResponse with combined results
"""
pgvector_search_operator = self.get_pgvector_search_operator()
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute( cur.execute(
f""" f"""
SELECT document, embedding <-> %s::vector AS distance SELECT document, embedding {pgvector_search_operator} %s::vector AS distance
FROM {self.table_name} FROM {self.table_name}
ORDER BY distance ORDER BY distance
LIMIT %s LIMIT %s
@ -147,7 +210,40 @@ class PGVectorIndex(EmbeddingIndex):
k: int, k: int,
score_threshold: float, score_threshold: float,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in PGVector") """
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
Args:
query_string: The text query for keyword search
k: Number of results to return
score_threshold: Minimum similarity score threshold
Returns:
QueryChunksResponse with combined results
"""
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
# Use plainto_tsquery to handle user input safely and ts_rank for relevance scoring
cur.execute(
f"""
SELECT document, ts_rank(content_tsvector, plainto_tsquery('english', %s)) AS score
FROM {self.table_name}
WHERE content_tsvector @@ plainto_tsquery('english', %s)
ORDER BY score DESC
LIMIT %s
""",
(query_string, query_string, k),
)
results = cur.fetchall()
chunks = []
scores = []
for doc, score in results:
if score < score_threshold:
continue
chunks.append(Chunk(**doc))
scores.append(float(score))
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid( async def query_hybrid(
self, self,
@ -158,7 +254,57 @@ class PGVectorIndex(EmbeddingIndex):
reranker_type: str, reranker_type: str,
reranker_params: dict[str, Any] | None = None, reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in PGVector") """
Hybrid search combining vector similarity and keyword search using configurable reranking.
Args:
embedding: The query embedding vector
query_string: The text query for keyword search
k: Number of results to return
score_threshold: Minimum similarity score threshold
reranker_type: Type of reranker to use ("rrf" or "weighted")
reranker_params: Parameters for the reranker
Returns:
QueryChunksResponse with combined results
"""
if reranker_params is None:
reranker_params = {}
# Get results from both search methods
vector_response = await self.query_vector(embedding, k, score_threshold)
keyword_response = await self.query_keyword(query_string, k, score_threshold)
# Convert responses to score dictionaries using chunk_id
vector_scores = {
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
chunk.chunk_id: score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
# Combine scores using the reranking utility
combined_scores = Reranker.combine_search_results(vector_scores, keyword_scores, reranker_type, reranker_params)
# Efficient top-k selection because it only tracks the k best candidates it's seen so far
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
# Filter by score threshold
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create a map of chunk_id to chunk for both responses
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
# Use the map to look up chunks by their IDs
chunks = []
scores = []
for doc_id, score in filtered_items:
if doc_id in chunk_map:
chunks.append(chunk_map[doc_id])
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete(self): async def delete(self):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
@ -170,6 +316,25 @@ class PGVectorIndex(EmbeddingIndex):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,)) cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
def get_pgvector_search_operator(self) -> str:
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_OPERATOR[self.distance_metric]
def check_distance_metric_availability(self, distance_metric: str) -> None:
"""Check if the distance metric is supported by PGVector.
Args:
distance_metric: The distance metric to check
Raises:
ValueError: If the distance metric is not supported
"""
if distance_metric not in self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_OPERATOR:
supported_metrics = list(self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_OPERATOR.keys())
raise ValueError(
f"Distance metric '{distance_metric}' is not supported by PGVector. "
f"Supported metrics are: {', '.join(supported_metrics)}"
)
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__( def __init__(
@ -185,8 +350,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
self.files_api = files_api self.files_api = files_api
self.kvstore: KVStore | None = None self.kvstore: KVStore | None = None
self.vector_db_store = None self.vector_db_store = None
self.openai_vector_store: dict[str, dict[str, Any]] = {} self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.metadatadata_collection_name = "openai_vector_stores_metadata" self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None: async def initialize(self) -> None:
log.info(f"Initializing PGVector memory adapter with config: {self.config}") log.info(f"Initializing PGVector memory adapter with config: {self.config}")
@ -272,7 +437,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
if vector_db_id in self.cache: if vector_db_id in self.cache:
return self.cache[vector_db_id] return self.cache[vector_db_id]
if self.vector_db_store is None:
raise VectorStoreNotFoundError(vector_db_id)
vector_db = await self.vector_db_store.get_vector_db(vector_db_id) vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise VectorStoreNotFoundError(vector_db_id)
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn) index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id] return self.cache[vector_db_id]

View file

@ -37,3 +37,122 @@ def sanitize_collection_name(name: str, weaviate_format=False) -> str:
else: else:
s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name)) s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name))
return s return s
class Reranker:
@staticmethod
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
"""
Normalize scores to 0-1 range using min-max normalization.
Args:
scores: dictionary of scores with document IDs as keys and scores as values
Returns:
Normalized scores with document IDs as keys and normalized scores as values
"""
if not scores:
return {}
min_score, max_score = min(scores.values()), max(scores.values())
score_range = max_score - min_score
if score_range > 0:
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
return dict.fromkeys(scores, 1.0)
@staticmethod
def weighted_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
alpha: float = 0.5,
) -> dict[str, float]:
"""
Rerank via weighted average of scores.
Args:
vector_scores: scores from vector search
keyword_scores: scores from keyword search
alpha: weight factor between 0 and 1 (default: 0.5)
0 = keyword only, 1 = vector only, 0.5 = equal weight
Returns:
All unique document IDs with weighted combined scores
"""
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
normalized_vector_scores = Reranker._normalize_scores(vector_scores)
normalized_keyword_scores = Reranker._normalize_scores(keyword_scores)
# Weighted formula: score = (1-alpha) * keyword_score + alpha * vector_score
# alpha=0 means keyword only, alpha=1 means vector only
return {
doc_id: ((1 - alpha) * normalized_keyword_scores.get(doc_id, 0.0))
+ (alpha * normalized_vector_scores.get(doc_id, 0.0))
for doc_id in all_ids
}
@staticmethod
def rrf_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
impact_factor: float = 60.0,
) -> dict[str, float]:
"""
Rerank via Reciprocal Rank Fusion.
Args:
vector_scores: scores from vector search
keyword_scores: scores from keyword search
impact_factor: impact factor for RRF (default: 60.0)
Returns:
All unique document IDs with RRF combined scores
"""
# Convert scores to ranks
vector_ranks = {
doc_id: i + 1
for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
}
keyword_ranks = {
doc_id: i + 1
for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
}
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
rrf_scores = {}
for doc_id in all_ids:
vector_rank = vector_ranks.get(doc_id, float("inf"))
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
# RRF formula: score = 1/(k + r) where k is impact_factor (default: 60.0) and r is the rank
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
return rrf_scores
@staticmethod
def combine_search_results(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
reranker_type: str = "rrf",
reranker_params: dict[str, float] | None = None,
) -> dict[str, float]:
"""
Combine vector and keyword search results using specified reranking strategy.
Args:
vector_scores: scores from vector search
keyword_scores: scores from keyword search
reranker_type: type of reranker to use (default: RERANKER_TYPE_RRF)
reranker_params: parameters for the reranker
Returns:
All unique document IDs with combined scores
"""
if reranker_params is None:
reranker_params = {}
if reranker_type == "weighted":
alpha = reranker_params.get("alpha", 0.5)
return Reranker.weighted_rerank(vector_scores, keyword_scores, alpha)
else:
# Default to RRF for None, RRF, or any unknown types
impact_factor = reranker_params.get("impact_factor", 60.0)
return Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor)

View file

@ -84,6 +84,7 @@ unit = [
"openai", "openai",
"aiosqlite", "aiosqlite",
"aiohttp", "aiohttp",
"psycopg2-binary>=2.9.0",
"pypdf", "pypdf",
"mcp", "mcp",
"chardet", "chardet",
@ -109,6 +110,7 @@ test = [
"torch>=2.6.0", "torch>=2.6.0",
"torchvision>=0.21.0", "torchvision>=0.21.0",
"chardet", "chardet",
"psycopg2-binary>=2.9.0",
"pypdf", "pypdf",
"mcp", "mcp",
"datasets", "datasets",

View file

@ -56,11 +56,13 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
"keyword": [ "keyword": [
"inline::sqlite-vec", "inline::sqlite-vec",
"remote::milvus", "remote::milvus",
"remote::pgvector",
], ],
"hybrid": [ "hybrid": [
"inline::sqlite-vec", "inline::sqlite-vec",
"inline::milvus", "inline::milvus",
"remote::milvus", "remote::milvus",
"remote::pgvector",
], ],
} }
supported_providers = search_mode_support.get(search_mode, []) supported_providers = search_mode_support.get(search_mode, [])

View file

@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.memory.vector_store import RERANKER_TYPE_RRF, RERANKER_TYPE_WEIGHTED
from llama_stack.providers.utils.vector_io.vector_utils import Reranker
class TestNormalizeScores:
"""Test cases for score normalization."""
def test_normalize_scores_basic(self):
"""Test basic score normalization."""
scores = {"doc1": 10.0, "doc2": 5.0, "doc3": 0.0}
normalized = Reranker._normalize_scores(scores)
assert normalized["doc1"] == 1.0 # Max score
assert normalized["doc3"] == 0.0 # Min score
assert normalized["doc2"] == 0.5 # Middle score
assert all(0 <= score <= 1 for score in normalized.values())
def test_normalize_scores_identical(self):
"""Test normalization when all scores are identical."""
scores = {"doc1": 5.0, "doc2": 5.0, "doc3": 5.0}
normalized = Reranker._normalize_scores(scores)
# All scores should be 1.0 when identical
assert all(score == 1.0 for score in normalized.values())
def test_normalize_scores_empty(self):
"""Test normalization with empty scores."""
scores = {}
normalized = Reranker._normalize_scores(scores)
assert normalized == {}
def test_normalize_scores_single(self):
"""Test normalization with single score."""
scores = {"doc1": 7.5}
normalized = Reranker._normalize_scores(scores)
assert normalized["doc1"] == 1.0
class TestWeightedRerank:
"""Test cases for weighted reranking."""
def test_weighted_rerank_basic(self):
"""Test basic weighted reranking."""
vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}
keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9}
combined = Reranker.weighted_rerank(vector_scores, keyword_scores, alpha=0.5)
# Should include all documents
expected_docs = {"doc1", "doc2", "doc3", "doc4"}
assert set(combined.keys()) == expected_docs
# All scores should be between 0 and 1
assert all(0 <= score <= 1 for score in combined.values())
# doc1 appears in both searches, should have higher combined score
assert combined["doc1"] > 0
def test_weighted_rerank_alpha_zero(self):
"""Test weighted reranking with alpha=0 (keyword only)."""
vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} # All docs present in vector
keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9} # All docs present in keyword
combined = Reranker.weighted_rerank(vector_scores, keyword_scores, alpha=0.0)
# Alpha=0 means vector scores are ignored, keyword scores dominate
# doc3 should score highest since it has highest keyword score
assert combined["doc3"] > combined["doc2"] > combined["doc1"]
def test_weighted_rerank_alpha_one(self):
"""Test weighted reranking with alpha=1 (vector only)."""
vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} # All docs present in vector
keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9} # All docs present in keyword
combined = Reranker.weighted_rerank(vector_scores, keyword_scores, alpha=1.0)
# Alpha=1 means keyword scores are ignored, vector scores dominate
# doc1 should score highest since it has highest vector score
assert combined["doc1"] > combined["doc2"] > combined["doc3"]
def test_weighted_rerank_no_overlap(self):
"""Test weighted reranking with no overlapping documents."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc3": 0.8, "doc4": 0.6}
combined = Reranker.weighted_rerank(vector_scores, keyword_scores, alpha=0.5)
assert len(combined) == 4
# With min-max normalization, lowest scoring docs in each group get 0.0
# but highest scoring docs should get positive scores
assert all(score >= 0 for score in combined.values())
assert combined["doc1"] > 0 # highest vector score
assert combined["doc3"] > 0 # highest keyword score
class TestRRFRerank:
"""Test cases for RRF (Reciprocal Rank Fusion) reranking."""
def test_rrf_rerank_basic(self):
"""Test basic RRF reranking."""
vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}
keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9}
combined = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)
# Should include all documents
expected_docs = {"doc1", "doc2", "doc3", "doc4"}
assert set(combined.keys()) == expected_docs
# All scores should be positive
assert all(score > 0 for score in combined.values())
# Documents appearing in both searches should have higher scores
# doc1 and doc2 appear in both, doc3 and doc4 appear in only one
assert combined["doc1"] > combined["doc3"]
assert combined["doc2"] > combined["doc4"]
def test_rrf_rerank_rank_calculation(self):
"""Test that RRF correctly calculates ranks."""
# Create clear ranking order
vector_scores = {"doc1": 1.0, "doc2": 0.8, "doc3": 0.6} # Ranks: 1, 2, 3
keyword_scores = {"doc1": 0.5, "doc2": 1.0, "doc3": 0.7} # Ranks: 3, 1, 2
combined = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)
# doc1: rank 1 in vector, rank 3 in keyword
# doc2: rank 2 in vector, rank 1 in keyword
# doc3: rank 3 in vector, rank 2 in keyword
# doc2 should have the highest combined score (ranks 2+1=3)
# followed by doc1 (ranks 1+3=4) and doc3 (ranks 3+2=5)
# Remember: lower rank sum = higher RRF score
assert combined["doc2"] > combined["doc1"] > combined["doc3"]
def test_rrf_rerank_impact_factor(self):
"""Test that impact factor affects RRF scores."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc1": 0.8, "doc2": 0.6}
combined_low = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor=10.0)
combined_high = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor=100.0)
# Higher impact factor should generally result in lower scores
# (because 1/(k+r) decreases as k increases)
assert combined_low["doc1"] > combined_high["doc1"]
assert combined_low["doc2"] > combined_high["doc2"]
def test_rrf_rerank_missing_documents(self):
"""Test RRF handling of documents missing from one search."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc1": 0.8, "doc3": 0.6}
combined = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)
# Should include all documents
assert len(combined) == 3
# doc1 appears in both searches, should have highest score
assert combined["doc1"] > combined["doc2"]
assert combined["doc1"] > combined["doc3"]
class TestCombineSearchResults:
"""Test cases for the main combine_search_results function."""
def test_combine_search_results_rrf_default(self):
"""Test combining with RRF as default."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc1": 0.6, "doc3": 0.8}
combined = Reranker.combine_search_results(vector_scores, keyword_scores)
# Should default to RRF
assert len(combined) == 3
assert all(score > 0 for score in combined.values())
def test_combine_search_results_rrf_explicit(self):
"""Test combining with explicit RRF."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc1": 0.6, "doc3": 0.8}
combined = Reranker.combine_search_results(
vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_RRF, reranker_params={"impact_factor": 50.0}
)
assert len(combined) == 3
assert all(score > 0 for score in combined.values())
def test_combine_search_results_weighted(self):
"""Test combining with weighted reranking."""
vector_scores = {"doc1": 0.9, "doc2": 0.7}
keyword_scores = {"doc1": 0.6, "doc3": 0.8}
combined = Reranker.combine_search_results(
vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_WEIGHTED, reranker_params={"alpha": 0.3}
)
assert len(combined) == 3
assert all(0 <= score <= 1 for score in combined.values())
def test_combine_search_results_unknown_type(self):
"""Test combining with unknown reranker type defaults to RRF."""
vector_scores = {"doc1": 0.9}
keyword_scores = {"doc2": 0.8}
combined = Reranker.combine_search_results(vector_scores, keyword_scores, reranker_type="unknown_type")
# Should fall back to RRF
assert len(combined) == 2
assert all(score > 0 for score in combined.values())
def test_combine_search_results_empty_params(self):
"""Test combining with empty parameters."""
vector_scores = {"doc1": 0.9}
keyword_scores = {"doc2": 0.8}
combined = Reranker.combine_search_results(vector_scores, keyword_scores, reranker_params={})
# Should use default parameters
assert len(combined) == 2
assert all(score > 0 for score in combined.values())
def test_combine_search_results_empty_scores(self):
"""Test combining with empty score dictionaries."""
# Test with empty vector scores
combined = Reranker.combine_search_results({}, {"doc1": 0.8})
assert len(combined) == 1
assert combined["doc1"] > 0
# Test with empty keyword scores
combined = Reranker.combine_search_results({"doc1": 0.9}, {})
assert len(combined) == 1
assert combined["doc1"] > 0
# Test with both empty
combined = Reranker.combine_search_results({}, {})
assert len(combined) == 0

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import random import random
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
@ -12,7 +13,7 @@ from chromadb import PersistentClient
from pymilvus import MilvusClient, connections from pymilvus import MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
@ -22,6 +23,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
EMBEDDING_DIMENSION = 384 EMBEDDING_DIMENSION = 384
@ -29,7 +32,7 @@ COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus" MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"]) @pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
def vector_provider(request): def vector_provider(request):
return request.param return request.param
@ -333,15 +336,127 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
await index.delete() await index.delete()
@pytest.fixture
def mock_psycopg2_connection():
connection = MagicMock()
cursor = MagicMock()
cursor.__enter__ = MagicMock(return_value=cursor)
cursor.__exit__ = MagicMock()
connection.cursor.return_value = cursor
return connection, cursor
@pytest.fixture
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id="pgvector",
provider_resource_id="pgvector:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
index._test_chunks = []
original_add_chunks = index.add_chunks
async def mock_add_chunks(chunks, embeddings):
index._test_chunks = list(chunks)
await original_add_chunks(chunks, embeddings)
index.add_chunks = mock_add_chunks
async def mock_query_vector(embedding, k, score_threshold):
chunks = index._test_chunks[:k] if hasattr(index, "_test_chunks") else []
scores = [1.0] * len(chunks)
return QueryChunksResponse(chunks=chunks, scores=scores)
index.query_vector = mock_query_vector
yield index
@pytest.fixture
async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
config = PGVectorVectorIOConfig(
host="localhost",
port=5432,
db="test_db",
user="test_user",
password="test_password",
kvstore=SqliteKVStoreConfig(),
)
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2.connect") as mock_connect:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
mock_cursor.__exit__ = MagicMock()
mock_conn.cursor.return_value = mock_cursor
mock_conn.autocommit = True
mock_connect.return_value = mock_conn
with patch(
"llama_stack.providers.remote.vector_io.pgvector.pgvector.check_extension_version"
) as mock_check_version:
mock_check_version.return_value = "0.5.1"
with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl:
mock_kvstore = AsyncMock()
mock_kvstore_impl.return_value = mock_kvstore
with patch.object(adapter, "initialize_openai_vector_stores", new_callable=AsyncMock):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.upsert_models"):
await adapter.initialize()
adapter.conn = mock_conn
async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
await index.insert_chunks(chunks)
adapter.insert_chunks = mock_insert_chunks
async def mock_query_chunks(vector_db_id, query, params=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
adapter.query_chunks = mock_query_chunks
test_vector_db = VectorDB(
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
await adapter.register_vector_db(test_vector_db)
adapter.test_collection_id = test_vector_db.identifier
yield adapter
await adapter.shutdown()
@pytest.fixture @pytest.fixture
def vector_io_adapter(vector_provider, request): def vector_io_adapter(vector_provider, request):
"""Returns the appropriate vector IO adapter based on the provider parameter."""
vector_provider_dict = { vector_provider_dict = {
"milvus": "milvus_vec_adapter", "milvus": "milvus_vec_adapter",
"faiss": "faiss_vec_adapter", "faiss": "faiss_vec_adapter",
"sqlite_vec": "sqlite_vec_adapter", "sqlite_vec": "sqlite_vec_adapter",
"chroma": "chroma_vec_adapter", "chroma": "chroma_vec_adapter",
"qdrant": "qdrant_vec_adapter", "qdrant": "qdrant_vec_adapter",
"pgvector": "pgvector_vec_adapter",
} }
return request.getfixturevalue(vector_provider_dict[vector_provider]) return request.getfixturevalue(vector_provider_dict[vector_provider])

View file

@ -0,0 +1,502 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion
PGVECTOR_PROVIDER = "pgvector"
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
@pytest.fixture
def embedding_dimension():
"""Default embedding dimension for tests."""
return 384
@pytest.fixture
def mock_psycopg2_connection():
"""Create a mock psycopg2 connection for testing."""
connection = MagicMock()
cursor = MagicMock()
# Mock the cursor context manager
cursor.__enter__ = MagicMock(return_value=cursor)
cursor.__exit__ = MagicMock()
# Mock connection cursor method
connection.cursor.return_value = cursor
return connection, cursor
@pytest.fixture
async def pgvector_index(embedding_dimension, mock_psycopg2_connection):
"""Create a PGVectorIndex instance with mocked database connection."""
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
# Use explicit COSINE distance metric for consistent testing
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
return index, cursor
def create_sample_chunks():
"""Create sample chunks for testing."""
return [
Chunk(
content="Machine learning is a subset of artificial intelligence",
metadata={"document_id": "doc-1", "topic": "AI"},
chunk_metadata=ChunkMetadata(document_id="doc-1", chunk_id="chunk-1"),
),
Chunk(
content="Deep learning uses neural networks with multiple layers",
metadata={"document_id": "doc-2", "topic": "AI"},
chunk_metadata=ChunkMetadata(document_id="doc-2", chunk_id="chunk-2"),
),
Chunk(
content="Natural language processing enables computers to understand text",
metadata={"document_id": "doc-3", "topic": "NLP"},
chunk_metadata=ChunkMetadata(document_id="doc-3", chunk_id="chunk-3"),
),
]
def create_sample_embeddings(num_chunks, dimension=384):
"""Create sample embeddings for testing."""
np.random.seed(42)
return np.array([np.random.rand(dimension).astype(np.float32) for _ in range(num_chunks)])
class TestPGVectorIndex:
def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2")
assert index.distance_metric == "L2"
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID")
def test_get_pgvector_search_operator(self, pgvector_index):
index, cursor = pgvector_index
assert index.get_pgvector_search_operator() == "<=>"
index.distance_metric = "L2"
assert index.get_pgvector_search_operator() == "<->"
index.distance_metric = "L1"
assert index.get_pgvector_search_operator() == "<+>"
index.distance_metric = "INNER_PRODUCT"
assert index.get_pgvector_search_operator() == "<#>"
index.distance_metric = "HAMMING"
assert index.get_pgvector_search_operator() == "<~>"
index.distance_metric = "JACCARD"
assert index.get_pgvector_search_operator() == "<%>"
def test_check_distance_metric_availability(self, pgvector_index):
index, cursor = pgvector_index
index.check_distance_metric_availability("COSINE")
index.check_distance_metric_availability("L2")
index.check_distance_metric_availability("L1")
index.check_distance_metric_availability("INNER_PRODUCT")
index.check_distance_metric_availability("HAMMING")
index.check_distance_metric_availability("JACCARD")
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
index.check_distance_metric_availability("INVALID")
async def test_add_chunks(self, pgvector_index, embedding_dimension):
index, cursor = pgvector_index
chunks = create_sample_chunks()
embeddings = create_sample_embeddings(len(chunks), embedding_dimension)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values") as mock_execute_values:
await index.add_chunks(chunks, embeddings)
assert mock_execute_values.called
call_args = mock_execute_values.call_args
query_arg = str(call_args[0][1])
assert "INSERT INTO" in query_arg
assert "content_tsvector" in query_arg
async def test_query_vector(self, pgvector_index, embedding_dimension):
index, cursor = pgvector_index
query_embedding = create_sample_embeddings(1, embedding_dimension)[0]
mock_results = [
({"content": "test content", "metadata": {}}, 0.1),
({"content": "test content 2", "metadata": {}}, 0.2),
]
cursor.fetchall.return_value = mock_results
response = await index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
call_args = cursor.execute.call_args
assert "<=>" in str(call_args) or "ORDER BY" in str(call_args)
async def test_query_vector_different_metrics(self, pgvector_index, embedding_dimension):
index, cursor = pgvector_index
query_embedding = create_sample_embeddings(1, embedding_dimension)[0]
mock_results = [
({"content": "test content", "metadata": {}}, 0.1),
]
cursor.fetchall.return_value = mock_results
# Test L2 distance
index.distance_metric = "L2"
await index.query_vector(query_embedding, k=1, score_threshold=0.0)
call_args = cursor.execute.call_args
assert "<->" in str(call_args[0][0]) # L2 operator
# Test L1 distance
index.distance_metric = "L1"
await index.query_vector(query_embedding, k=1, score_threshold=0.0)
call_args = cursor.execute.call_args
assert "<+>" in str(call_args[0][0]) # L1 operator
# Test INNER_PRODUCT distance
index.distance_metric = "INNER_PRODUCT"
await index.query_vector(query_embedding, k=1, score_threshold=0.0)
call_args = cursor.execute.call_args
assert "<#>" in str(call_args[0][0]) # Inner product operator
# Test Hamming distance
index.distance_metric = "HAMMING"
await index.query_vector(query_embedding, k=1, score_threshold=0.0)
call_args = cursor.execute.call_args
assert "<~>" in str(call_args[0][0]) # Hamming operator
# Test Jaccard distance
index.distance_metric = "JACCARD"
await index.query_vector(query_embedding, k=1, score_threshold=0.0)
call_args = cursor.execute.call_args
assert "<%>" in str(call_args[0][0]) # Jaccard operator
async def test_query_keyword(self, pgvector_index):
index, cursor = pgvector_index
query_string = "machine learning"
mock_results = [
({"content": "Machine learning is great", "metadata": {}}, 0.8),
({"content": "Learning machines are useful", "metadata": {}}, 0.6),
]
cursor.fetchall.return_value = mock_results
response = await index.query_keyword(query_string, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
call_args = cursor.execute.call_args
assert "ts_rank" in str(call_args) or "plainto_tsquery" in str(call_args)
async def test_query_keyword_with_score_threshold(self, pgvector_index):
index, cursor = pgvector_index
query_string = "machine learning"
score_threshold = 0.7
# Mock database response with mixed scores
mock_results = [
({"content": "Machine learning is great", "metadata": {}}, 0.8), # Above threshold
({"content": "Learning machines are useful", "metadata": {}}, 0.5), # Below threshold
]
cursor.fetchall.return_value = mock_results
response = await index.query_keyword(query_string, k=2, score_threshold=score_threshold)
assert len(response.chunks) == 1
assert response.scores[0] >= score_threshold
async def test_query_hybrid_rrf(self, pgvector_index, embedding_dimension):
index, cursor = pgvector_index
query_embedding = create_sample_embeddings(1, embedding_dimension)[0]
query_string = "machine learning"
# Mock responses for both vector and keyword searches
vector_results = [
({"content": "Vector result 1", "metadata": {}, "chunk_id": "chunk-1"}, 0.9),
({"content": "Vector result 2", "metadata": {}, "chunk_id": "chunk-2"}, 0.7),
]
keyword_results = [
({"content": "Keyword result 1", "metadata": {}, "chunk_id": "chunk-3"}, 0.8),
({"content": "Vector result 1", "metadata": {}, "chunk_id": "chunk-1"}, 0.6), # Overlap
]
cursor.fetchall.side_effect = [vector_results, keyword_results]
response = await index.query_hybrid(
query_embedding,
query_string,
k=3,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) >= 1 # At least the overlapping chunk
assert cursor.execute.call_count >= 2
async def test_query_hybrid_weighted(self, pgvector_index, embedding_dimension):
index, cursor = pgvector_index
query_embedding = create_sample_embeddings(1, embedding_dimension)[0]
query_string = "machine learning"
vector_results = [
({"content": "Vector result", "metadata": {}, "chunk_id": "chunk-1"}, 0.9),
]
keyword_results = [
({"content": "Keyword result", "metadata": {}, "chunk_id": "chunk-2"}, 0.8),
]
cursor.fetchall.side_effect = [vector_results, keyword_results]
response = await index.query_hybrid(
query_embedding,
query_string,
k=2,
score_threshold=0.0,
reranker_type="weighted",
reranker_params={"alpha": 0.7},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC")
with pytest.raises(ValueError, match="Supported metrics are:"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN")
try:
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
assert index.distance_metric == "COSINE"
except ValueError:
pytest.fail("Valid distance metric 'COSINE' should not raise ValueError")
def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"]
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
for metric in supported_metrics:
try:
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric)
assert index.distance_metric == metric
expected_operators = {
"L2": "<->",
"L1": "<+>",
"COSINE": "<=>",
"INNER_PRODUCT": "<#>",
"HAMMING": "<~>",
"JACCARD": "<%>",
}
assert index.get_pgvector_search_operator() == expected_operators[metric]
except Exception as e:
pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}")
def test_error_handling_in_constructor(self, embedding_dimension):
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
mock_connection = MagicMock()
mock_cursor_context = MagicMock()
mock_cursor_context.__enter__.side_effect = Exception("Database connection failed")
mock_cursor_context.__exit__ = MagicMock()
mock_connection.cursor.return_value = mock_cursor_context
with pytest.raises(RuntimeError, match="Error creating PGVectorIndex"):
PGVectorIndex(vector_db, embedding_dimension, mock_connection, distance_metric="COSINE")
async def test_delete_chunks(self, pgvector_index):
index, cursor = pgvector_index
chunks_for_deletion = [
ChunkForDeletion(chunk_id="test-chunk-1", document_id="doc-1"),
ChunkForDeletion(chunk_id="test-chunk-2", document_id="doc-2"),
]
await index.delete_chunks(chunks_for_deletion)
cursor.execute.assert_called()
call_args = cursor.execute.call_args
assert "DELETE FROM" in str(call_args)
assert "test-chunk-1" in str(call_args) or "test-chunk-2" in str(call_args)
async def test_delete_index(self, pgvector_index):
"""Test deleting the entire index."""
index, cursor = pgvector_index
await index.delete()
cursor.execute.assert_called_with(f"DROP TABLE IF EXISTS {index.table_name}")
class TestPGVectorVectorIOAdapter:
@pytest.fixture
async def pgvector_adapter(self):
config = PGVectorVectorIOConfig(
host="localhost",
port=5432,
db="test_db",
user="test_user",
password="test_password",
kvstore={"type": "sqlite", "config": {"db_path": ":memory:"}},
)
inference_api = AsyncMock()
files_api = AsyncMock()
adapter = PGVectorVectorIOAdapter(config, inference_api, files_api)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2.connect") as mock_connect:
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
mock_cursor.__exit__ = MagicMock()
mock_conn.cursor.return_value = mock_cursor
mock_connect.return_value = mock_conn
with patch.object(adapter, "kvstore") as mock_kvstore:
mock_kvstore.set = AsyncMock()
mock_kvstore.get = AsyncMock()
yield adapter, mock_conn, mock_cursor
async def test_initialization(self, pgvector_adapter):
adapter, mock_conn, mock_cursor = pgvector_adapter
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.kvstore_impl") as mock_kvstore_impl:
mock_kvstore = AsyncMock()
mock_kvstore_impl.return_value = mock_kvstore
mock_cursor.fetchone.return_value = ["0.5.0"]
await adapter.initialize()
assert adapter.conn is not None
mock_cursor.execute.assert_called()
assert adapter.metadata_collection_name == "openai_vector_stores_metadata"
async def test_register_vector_db(self, pgvector_adapter):
adapter, mock_conn, mock_cursor = pgvector_adapter
vector_db = VectorDB(
identifier="test-db",
embedding_model="test-model",
embedding_dimension=384,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
adapter.kvstore = AsyncMock()
adapter.conn = mock_conn
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.upsert_models"):
await adapter.register_vector_db(vector_db)
assert "test-db" in adapter.cache
async def test_insert_chunks(self, pgvector_adapter):
adapter, mock_conn, mock_cursor = pgvector_adapter
adapter.conn = mock_conn
chunks = create_sample_chunks()
mock_index = AsyncMock()
adapter.cache["test-db"] = mock_index
await adapter.insert_chunks("test-db", chunks)
mock_index.insert_chunks.assert_called_once_with(chunks)
async def test_delete_chunks(self, pgvector_adapter):
adapter, mock_conn, mock_cursor = pgvector_adapter
adapter.conn = mock_conn
mock_index = AsyncMock()
mock_index.index = AsyncMock()
adapter.cache["test-db"] = mock_index
chunks_for_deletion = [
ChunkForDeletion(chunk_id="chunk-1", document_id="doc-1"),
ChunkForDeletion(chunk_id="chunk-2", document_id="doc-2"),
]
await adapter.delete_chunks("test-db", chunks_for_deletion)
mock_index.index.delete_chunks.assert_called_once_with(chunks_for_deletion)

35
uv.lock generated
View file

@ -1810,6 +1810,7 @@ test = [
{ name = "datasets" }, { name = "datasets" },
{ name = "mcp" }, { name = "mcp" },
{ name = "openai" }, { name = "openai" },
{ name = "psycopg2-binary" },
{ name = "pymilvus" }, { name = "pymilvus" },
{ name = "pypdf" }, { name = "pypdf" },
{ name = "requests" }, { name = "requests" },
@ -1833,6 +1834,7 @@ unit = [
{ name = "mcp" }, { name = "mcp" },
{ name = "ollama" }, { name = "ollama" },
{ name = "openai" }, { name = "openai" },
{ name = "psycopg2-binary" },
{ name = "pymilvus" }, { name = "pymilvus" },
{ name = "pypdf" }, { name = "pypdf" },
{ name = "qdrant-client" }, { name = "qdrant-client" },
@ -1926,6 +1928,7 @@ test = [
{ name = "datasets" }, { name = "datasets" },
{ name = "mcp" }, { name = "mcp" },
{ name = "openai" }, { name = "openai" },
{ name = "psycopg2-binary", specifier = ">=2.9.0" },
{ name = "pymilvus", specifier = ">=2.5.12" }, { name = "pymilvus", specifier = ">=2.5.12" },
{ name = "pypdf" }, { name = "pypdf" },
{ name = "requests" }, { name = "requests" },
@ -1948,6 +1951,7 @@ unit = [
{ name = "mcp" }, { name = "mcp" },
{ name = "ollama" }, { name = "ollama" },
{ name = "openai" }, { name = "openai" },
{ name = "psycopg2-binary", specifier = ">=2.9.0" },
{ name = "pymilvus", specifier = ">=2.5.12" }, { name = "pymilvus", specifier = ">=2.5.12" },
{ name = "pypdf" }, { name = "pypdf" },
{ name = "qdrant-client" }, { name = "qdrant-client" },
@ -3058,6 +3062,37 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885, upload-time = "2025-02-13T21:54:37.486Z" }, { url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885, upload-time = "2025-02-13T21:54:37.486Z" },
] ]
[[package]]
name = "psycopg2-binary"
version = "2.9.10"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/cb/0e/bdc8274dc0585090b4e3432267d7be4dfbfd8971c0fa59167c711105a6bf/psycopg2-binary-2.9.10.tar.gz", hash = "sha256:4b3df0e6990aa98acda57d983942eff13d824135fe2250e6522edaa782a06de2", size = 385764, upload-time = "2024-10-16T11:24:58.126Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/49/7d/465cc9795cf76f6d329efdafca74693714556ea3891813701ac1fee87545/psycopg2_binary-2.9.10-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:880845dfe1f85d9d5f7c412efea7a08946a46894537e4e5d091732eb1d34d9a0", size = 3044771, upload-time = "2024-10-16T11:20:35.234Z" },
{ url = "https://files.pythonhosted.org/packages/8b/31/6d225b7b641a1a2148e3ed65e1aa74fc86ba3fee850545e27be9e1de893d/psycopg2_binary-2.9.10-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:9440fa522a79356aaa482aa4ba500b65f28e5d0e63b801abf6aa152a29bd842a", size = 3275336, upload-time = "2024-10-16T11:20:38.742Z" },
{ url = "https://files.pythonhosted.org/packages/30/b7/a68c2b4bff1cbb1728e3ec864b2d92327c77ad52edcd27922535a8366f68/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3923c1d9870c49a2d44f795df0c889a22380d36ef92440ff618ec315757e539", size = 2851637, upload-time = "2024-10-16T11:20:42.145Z" },
{ url = "https://files.pythonhosted.org/packages/0b/b1/cfedc0e0e6f9ad61f8657fd173b2f831ce261c02a08c0b09c652b127d813/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b2c956c028ea5de47ff3a8d6b3cc3330ab45cf0b7c3da35a2d6ff8420896526", size = 3082097, upload-time = "2024-10-16T11:20:46.185Z" },
{ url = "https://files.pythonhosted.org/packages/18/ed/0a8e4153c9b769f59c02fb5e7914f20f0b2483a19dae7bf2db54b743d0d0/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f758ed67cab30b9a8d2833609513ce4d3bd027641673d4ebc9c067e4d208eec1", size = 3264776, upload-time = "2024-10-16T11:20:50.879Z" },
{ url = "https://files.pythonhosted.org/packages/10/db/d09da68c6a0cdab41566b74e0a6068a425f077169bed0946559b7348ebe9/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cd9b4f2cfab88ed4a9106192de509464b75a906462fb846b936eabe45c2063e", size = 3020968, upload-time = "2024-10-16T11:20:56.819Z" },
{ url = "https://files.pythonhosted.org/packages/94/28/4d6f8c255f0dfffb410db2b3f9ac5218d959a66c715c34cac31081e19b95/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dc08420625b5a20b53551c50deae6e231e6371194fa0651dbe0fb206452ae1f", size = 2872334, upload-time = "2024-10-16T11:21:02.411Z" },
{ url = "https://files.pythonhosted.org/packages/05/f7/20d7bf796593c4fea95e12119d6cc384ff1f6141a24fbb7df5a668d29d29/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d7cd730dfa7c36dbe8724426bf5612798734bff2d3c3857f36f2733f5bfc7c00", size = 2822722, upload-time = "2024-10-16T11:21:09.01Z" },
{ url = "https://files.pythonhosted.org/packages/4d/e4/0c407ae919ef626dbdb32835a03b6737013c3cc7240169843965cada2bdf/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:155e69561d54d02b3c3209545fb08938e27889ff5a10c19de8d23eb5a41be8a5", size = 2920132, upload-time = "2024-10-16T11:21:16.339Z" },
{ url = "https://files.pythonhosted.org/packages/2d/70/aa69c9f69cf09a01da224909ff6ce8b68faeef476f00f7ec377e8f03be70/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3cc28a6fd5a4a26224007712e79b81dbaee2ffb90ff406256158ec4d7b52b47", size = 2959312, upload-time = "2024-10-16T11:21:25.584Z" },
{ url = "https://files.pythonhosted.org/packages/d3/bd/213e59854fafe87ba47814bf413ace0dcee33a89c8c8c814faca6bc7cf3c/psycopg2_binary-2.9.10-cp312-cp312-win32.whl", hash = "sha256:ec8a77f521a17506a24a5f626cb2aee7850f9b69a0afe704586f63a464f3cd64", size = 1025191, upload-time = "2024-10-16T11:21:29.912Z" },
{ url = "https://files.pythonhosted.org/packages/92/29/06261ea000e2dc1e22907dbbc483a1093665509ea586b29b8986a0e56733/psycopg2_binary-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:18c5ee682b9c6dd3696dad6e54cc7ff3a1a9020df6a5c0f861ef8bfd338c3ca0", size = 1164031, upload-time = "2024-10-16T11:21:34.211Z" },
{ url = "https://files.pythonhosted.org/packages/3e/30/d41d3ba765609c0763505d565c4d12d8f3c79793f0d0f044ff5a28bf395b/psycopg2_binary-2.9.10-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:26540d4a9a4e2b096f1ff9cce51253d0504dca5a85872c7f7be23be5a53eb18d", size = 3044699, upload-time = "2024-10-16T11:21:42.841Z" },
{ url = "https://files.pythonhosted.org/packages/35/44/257ddadec7ef04536ba71af6bc6a75ec05c5343004a7ec93006bee66c0bc/psycopg2_binary-2.9.10-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e217ce4d37667df0bc1c397fdcd8de5e81018ef305aed9415c3b093faaeb10fb", size = 3275245, upload-time = "2024-10-16T11:21:51.989Z" },
{ url = "https://files.pythonhosted.org/packages/1b/11/48ea1cd11de67f9efd7262085588790a95d9dfcd9b8a687d46caf7305c1a/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:245159e7ab20a71d989da00f280ca57da7641fa2cdcf71749c193cea540a74f7", size = 2851631, upload-time = "2024-10-16T11:21:57.584Z" },
{ url = "https://files.pythonhosted.org/packages/62/e0/62ce5ee650e6c86719d621a761fe4bc846ab9eff8c1f12b1ed5741bf1c9b/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c4ded1a24b20021ebe677b7b08ad10bf09aac197d6943bfe6fec70ac4e4690d", size = 3082140, upload-time = "2024-10-16T11:22:02.005Z" },
{ url = "https://files.pythonhosted.org/packages/27/ce/63f946c098611f7be234c0dd7cb1ad68b0b5744d34f68062bb3c5aa510c8/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3abb691ff9e57d4a93355f60d4f4c1dd2d68326c968e7db17ea96df3c023ef73", size = 3264762, upload-time = "2024-10-16T11:22:06.412Z" },
{ url = "https://files.pythonhosted.org/packages/43/25/c603cd81402e69edf7daa59b1602bd41eb9859e2824b8c0855d748366ac9/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8608c078134f0b3cbd9f89b34bd60a943b23fd33cc5f065e8d5f840061bd0673", size = 3020967, upload-time = "2024-10-16T11:22:11.583Z" },
{ url = "https://files.pythonhosted.org/packages/5f/d6/8708d8c6fca531057fa170cdde8df870e8b6a9b136e82b361c65e42b841e/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:230eeae2d71594103cd5b93fd29d1ace6420d0b86f4778739cb1a5a32f607d1f", size = 2872326, upload-time = "2024-10-16T11:22:16.406Z" },
{ url = "https://files.pythonhosted.org/packages/ce/ac/5b1ea50fc08a9df82de7e1771537557f07c2632231bbab652c7e22597908/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909", size = 2822712, upload-time = "2024-10-16T11:22:21.366Z" },
{ url = "https://files.pythonhosted.org/packages/c4/fc/504d4503b2abc4570fac3ca56eb8fed5e437bf9c9ef13f36b6621db8ef00/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1", size = 2920155, upload-time = "2024-10-16T11:22:25.684Z" },
{ url = "https://files.pythonhosted.org/packages/b2/d1/323581e9273ad2c0dbd1902f3fb50c441da86e894b6e25a73c3fda32c57e/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567", size = 2959356, upload-time = "2024-10-16T11:22:30.562Z" },
{ url = "https://files.pythonhosted.org/packages/08/50/d13ea0a054189ae1bc21af1d85b6f8bb9bbc5572991055d70ad9006fe2d6/psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142", size = 2569224, upload-time = "2025-01-04T20:09:19.234Z" },
]
[[package]] [[package]]
name = "ptyprocess" name = "ptyprocess"
version = "0.7.0" version = "0.7.0"