diff --git a/docs/source/providers/vector_io/remote_pgvector.md b/docs/source/providers/vector_io/remote_pgvector.md index 74f588a13..689ac8ccc 100644 --- a/docs/source/providers/vector_io/remote_pgvector.md +++ b/docs/source/providers/vector_io/remote_pgvector.md @@ -12,6 +12,60 @@ That means you'll get fast and efficient vector retrieval. - Easy to use - Fully integrated with Llama Stack +Three implementations of search for PGVectorIndex: + +1. Vector Search: +- How it works: + - Uses PostgreSQL's vector extension (pgvector) to perform similarity search + - Compares query embeddings against stored embeddings using Cosine distance or other distance metrics + - E.g. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance + +- Characteristics: + - Semantic understanding - finds documents similar in meaning even if they don't share keywords + - Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions) + - Best for: Finding conceptually related content, handling synonyms, cross-language search + +2. Keyword Search +- How it works: + - Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank + - Converts text to searchable tokens using to_tsvector('english', text) + - E.g. SQL query: SELECT document, ts_rank(content_tsvector, plainto_tsquery('english', %s)) AS score + +- Characteristics: + - Lexical matching - finds exact keyword matches and variations + - Uses GIN (Generalized Inverted Index) for fast text search performance + - Scoring: Uses PostgreSQL's ts_rank function for relevance scoring + - Best for: Exact term matching, proper names, technical terms, Boolean-style queries + +3. 
Hybrid Search +- How it works: + - Combines both vector and keyword search results + - Runs both searches independently, then merges results using configurable reranking + +- Two reranking strategies available: + - Reciprocal Rank Fusion (RRF) - (default impact_factor: 60.0) + - Weighted Average - (default alpha: 0.5; 0 = keyword only, 1 = vector only) + +- Characteristics: + - Best of both worlds: semantic understanding + exact matching + - Documents appearing in both searches get boosted scores + - Configurable balance between semantic and lexical matching + - Best for: General-purpose search where you want both precision and recall + +4. Database Schema +The PGVector implementation stores data optimized for all three search types: +CREATE TABLE vector_store_xxx ( + id TEXT PRIMARY KEY, + document JSONB, -- Original document + embedding vector(dimension), -- For vector search + content_text TEXT, -- Raw text content + content_tsvector TSVECTOR -- For keyword search +); + +-- Indexes for performance +CREATE INDEX vector_store_xxx_content_gin_idx ON vector_store_xxx USING GIN(content_tsvector); -- Keyword search +-- No vector index is created here; pgvector runs an exact (sequential scan) search unless an HNSW or IVFFlat index is added + ## Usage To use PGVector in your Llama Stack project, follow these steps: diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index cc1982f3b..572e58517 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -35,6 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) +from llama_stack.providers.utils.vector_io.vector_utils import Reranker logger = logging.getLogger(__name__) @@ -66,59 +67,6 @@ def _create_sqlite_connection(db_path): return connection -def _normalize_scores(scores: dict[str, float]) -> dict[str, float]: - """Normalize scores to [0,1] range using min-max normalization.""" - if not scores: - return {} - min_score = min(scores.values()) - 
max_score = max(scores.values()) - score_range = max_score - min_score - if score_range > 0: - return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()} - return dict.fromkeys(scores, 1.0) - - -def _weighted_rerank( - vector_scores: dict[str, float], - keyword_scores: dict[str, float], - alpha: float = 0.5, -) -> dict[str, float]: - """ReRanker that uses weighted average of scores.""" - all_ids = set(vector_scores.keys()) | set(keyword_scores.keys()) - normalized_vector_scores = _normalize_scores(vector_scores) - normalized_keyword_scores = _normalize_scores(keyword_scores) - - return { - doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0)) - + ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0)) - for doc_id in all_ids - } - - -def _rrf_rerank( - vector_scores: dict[str, float], - keyword_scores: dict[str, float], - impact_factor: float = 60.0, -) -> dict[str, float]: - """ReRanker that uses Reciprocal Rank Fusion.""" - # Convert scores to ranks - vector_ranks = { - doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True)) - } - keyword_ranks = { - doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True)) - } - - all_ids = set(vector_scores.keys()) | set(keyword_scores.keys()) - rrf_scores = {} - for doc_id in all_ids: - vector_rank = vector_ranks.get(doc_id, float("inf")) - keyword_rank = keyword_ranks.get(doc_id, float("inf")) - # RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank - rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank)) - return rrf_scores - - def _make_sql_identifier(name: str) -> str: return re.sub(r"[^a-zA-Z0-9_]", "_", name) @@ -401,11 +349,11 @@ class SQLiteVecIndex(EmbeddingIndex): # Combine scores using the specified reranker if reranker_type == RERANKER_TYPE_WEIGHTED: alpha = reranker_params.get("alpha", 0.5) - 
combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha) + combined_scores = Reranker.weighted_rerank(vector_scores, keyword_scores, alpha) else: # Default to RRF for None, RRF, or any unknown types impact_factor = reranker_params.get("impact_factor", 60.0) - combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor) + combined_scores = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor) # Sort by combined score and get top k results sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True) diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 70148eb15..dbddca533 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -404,6 +404,60 @@ That means you'll get fast and efficient vector retrieval. - Easy to use - Fully integrated with Llama Stack +Three implementations of search for PGVectorIndex: + +1. Vector Search: +- How it works: + - Uses PostgreSQL's vector extension (pgvector) to perform similarity search + - Compares query embeddings against stored embeddings using Cosine distance or other distance metrics + - E.g. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance + +- Characteristics: + - Semantic understanding - finds documents similar in meaning even if they don't share keywords + - Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions) + - Best for: Finding conceptually related content, handling synonyms, cross-language search + +2. Keyword Search +- How it works: + - Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank + - Converts text to searchable tokens using to_tsvector('english', text) + - E.g. 
SQL query: SELECT document, ts_rank(content_tsvector, plainto_tsquery('english', %s)) AS score + +- Characteristics: + - Lexical matching - finds exact keyword matches and variations + - Uses GIN (Generalized Inverted Index) for fast text search performance + - Scoring: Uses PostgreSQL's ts_rank function for relevance scoring + - Best for: Exact term matching, proper names, technical terms, Boolean-style queries + +3. Hybrid Search +- How it works: + - Combines both vector and keyword search results + - Runs both searches independently, then merges results using configurable reranking + +- Two reranking strategies available: + - Reciprocal Rank Fusion (RRF) - (default impact_factor: 60.0) + - Weighted Average - (default alpha: 0.5; 0 = keyword only, 1 = vector only) + +- Characteristics: + - Best of both worlds: semantic understanding + exact matching + - Documents appearing in both searches get boosted scores + - Configurable balance between semantic and lexical matching + - Best for: General-purpose search where you want both precision and recall + +4. 
Database Schema +The PGVector implementation stores data optimized for all three search types: +CREATE TABLE vector_store_xxx ( + id TEXT PRIMARY KEY, + document JSONB, -- Original document + embedding vector(dimension), -- For vector search + content_text TEXT, -- Raw text content + content_tsvector TSVECTOR -- For keyword search +); + +-- Indexes for performance +CREATE INDEX vector_store_xxx_content_gin_idx ON vector_store_xxx USING GIN(content_tsvector); -- Keyword search +-- No vector index is created here; pgvector runs an exact (sequential scan) search unless an HNSW or IVFFlat index is added + ## Usage To use PGVector in your Llama Stack project, follow these steps: @@ -449,6 +503,7 @@ Weaviate supports: - Metadata filtering - Multi-modal retrieval + ## Usage To use Weaviate in your Llama Stack project, follow these steps: diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index d2a5d910b..a92727c7c 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+import heapq import logging from typing import Any @@ -23,6 +24,9 @@ from llama_stack.apis.vector_io import ( VectorIO, ) from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin @@ -31,6 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) +from llama_stack.providers.utils.vector_io.vector_utils import Reranker from .config import PGVectorVectorIOConfig @@ -72,25 +77,63 @@ def load_models(cur, cls): class PGVectorIndex(EmbeddingIndex): - def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None): - self.conn = conn - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - # Sanitize the table name by replacing hyphens with underscores - # SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens - # when created with patterns like "test-vector-db-{uuid4()}" - sanitized_identifier = vector_db.identifier.replace("-", "_") - self.table_name = f"vector_store_{sanitized_identifier}" - self.kvstore = kvstore + # reference: https://github.com/pgvector/pgvector?tab=readme-ov-file#querying + PGVECTOR_DISTANCE_METRIC_TO_SEARCH_OPERATOR: dict[str, str] = { + "L2": "<->", # Euclidean distance + "L1": "<+>", # Manhattan distance + "COSINE": "<=>", # Cosine distance + "INNER_PRODUCT": "<#>", # Inner product distance + "HAMMING": "<~>", # Hamming distance + "JACCARD": "<%>", # Jaccard distance + } - cur.execute( - f""" - CREATE TABLE IF NOT EXISTS {self.table_name} ( - id TEXT PRIMARY KEY, - document JSONB, - embedding vector({dimension}) + def __init__( + self, + vector_db: VectorDB, + dimension: int, + conn: 
psycopg2.extensions.connection, + kvstore: KVStore | None = None, + distance_metric: str = "COSINE", + ): + self.conn = conn + self.check_distance_metric_availability(distance_metric) + self.distance_metric = distance_metric + + try: + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + # Sanitize the table name by replacing hyphens with underscores + # SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens + # when created with patterns like "test-vector-db-{uuid4()}" + sanitized_identifier = vector_db.identifier.replace("-", "_") + self.table_name = f"vector_store_{sanitized_identifier}" + self.kvstore = kvstore + + cur.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id TEXT PRIMARY KEY, + document JSONB, + embedding vector({dimension}), + content_text TEXT, + content_tsvector TSVECTOR + ) + """ ) - """ - ) + + # Create GIN index for full-text search performance + cur.execute( + f""" + CREATE INDEX IF NOT EXISTS {self.table_name}_content_gin_idx + ON {self.table_name} USING GIN(content_tsvector) + """ + ) + except Exception as e: + log.exception(f"Error creating PGVectorIndex for vector_db: {vector_db.identifier}") + raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {vector_db.identifier}") from e + + async def initialize(self) -> None: + # PGVectorIndex does not require explicit initialization + pass async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( @@ -99,29 +142,49 @@ class PGVectorIndex(EmbeddingIndex): values = [] for i, chunk in enumerate(chunks): + content_text = interleaved_content_as_str(chunk.content) values.append( ( f"{chunk.chunk_id}", Json(chunk.model_dump()), embeddings[i].tolist(), + content_text, + content_text, # Pass content_text twice - once for content_text column, once for to_tsvector function ) ) query = sql.SQL( f""" - INSERT INTO {self.table_name} (id, document, embedding) + INSERT INTO 
{self.table_name} (id, document, embedding, content_text, content_tsvector) VALUES %s - ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document + ON CONFLICT (id) DO UPDATE SET + embedding = EXCLUDED.embedding, + document = EXCLUDED.document, + content_text = EXCLUDED.content_text, + content_tsvector = EXCLUDED.content_tsvector """ ) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - execute_values(cur, query, values, template="(%s, %s, %s::vector)") + execute_values(cur, query, values, template="(%s, %s, %s::vector, %s, to_tsvector('english', %s))") async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + """ + Performs vector similarity search using PostgreSQL's operators. Default distance metric is COSINE. + + Args: + embedding: The query embedding vector + k: Number of results to return + score_threshold: Minimum similarity score threshold + + Returns: + QueryChunksResponse with combined results + """ + pgvector_search_operator = self.get_pgvector_search_operator() + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute( f""" - SELECT document, embedding <-> %s::vector AS distance + SELECT document, embedding {pgvector_search_operator} %s::vector AS distance FROM {self.table_name} ORDER BY distance LIMIT %s @@ -147,7 +210,40 @@ class PGVectorIndex(EmbeddingIndex): k: int, score_threshold: float, ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in PGVector") + """ + Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring. 
+ + Args: + query_string: The text query for keyword search + k: Number of results to return + score_threshold: Minimum similarity score threshold + + Returns: + QueryChunksResponse with combined results + """ + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + # Use plainto_tsquery to handle user input safely and ts_rank for relevance scoring + cur.execute( + f""" + SELECT document, ts_rank(content_tsvector, plainto_tsquery('english', %s)) AS score + FROM {self.table_name} + WHERE content_tsvector @@ plainto_tsquery('english', %s) + ORDER BY score DESC + LIMIT %s + """, + (query_string, query_string, k), + ) + results = cur.fetchall() + + chunks = [] + scores = [] + for doc, score in results: + if score < score_threshold: + continue + chunks.append(Chunk(**doc)) + scores.append(float(score)) + + return QueryChunksResponse(chunks=chunks, scores=scores) async def query_hybrid( self, @@ -158,7 +254,57 @@ class PGVectorIndex(EmbeddingIndex): reranker_type: str, reranker_params: dict[str, Any] | None = None, ) -> QueryChunksResponse: - raise NotImplementedError("Hybrid search is not supported in PGVector") + """ + Hybrid search combining vector similarity and keyword search using configurable reranking. 
+ + Args: + embedding: The query embedding vector + query_string: The text query for keyword search + k: Number of results to return + score_threshold: Minimum similarity score threshold + reranker_type: Type of reranker to use ("rrf" or "weighted") + reranker_params: Parameters for the reranker + + Returns: + QueryChunksResponse with combined results + """ + if reranker_params is None: + reranker_params = {} + + # Get results from both search methods + vector_response = await self.query_vector(embedding, k, score_threshold) + keyword_response = await self.query_keyword(query_string, k, score_threshold) + + # Convert responses to score dictionaries using chunk_id + vector_scores = { + chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False) + } + keyword_scores = { + chunk.chunk_id: score + for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False) + } + + # Combine scores using the reranking utility + combined_scores = Reranker.combine_search_results(vector_scores, keyword_scores, reranker_type, reranker_params) + + # Efficient top-k selection because it only tracks the k best candidates it's seen so far + top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1]) + + # Filter by score threshold + filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold] + + # Create a map of chunk_id to chunk for both responses + chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks} + + # Use the map to look up chunks by their IDs + chunks = [] + scores = [] + for doc_id, score in filtered_items: + if doc_id in chunk_map: + chunks.append(chunk_map[doc_id]) + scores.append(score) + + return QueryChunksResponse(chunks=chunks, scores=scores) async def delete(self): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: @@ -170,6 +316,25 @@ class PGVectorIndex(EmbeddingIndex): with 
self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,)) + def get_pgvector_search_operator(self) -> str: + return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_OPERATOR[self.distance_metric] + + def check_distance_metric_availability(self, distance_metric: str) -> None: + """Check if the distance metric is supported by PGVector. + + Args: + distance_metric: The distance metric to check + + Raises: + ValueError: If the distance metric is not supported + """ + if distance_metric not in self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_OPERATOR: + supported_metrics = list(self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_OPERATOR.keys()) + raise ValueError( + f"Distance metric '{distance_metric}' is not supported by PGVector. " + f"Supported metrics are: {', '.join(supported_metrics)}" + ) + class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__( @@ -185,8 +350,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco self.files_api = files_api self.kvstore: KVStore | None = None self.vector_db_store = None - self.openai_vector_store: dict[str, dict[str, Any]] = {} - self.metadatadata_collection_name = "openai_vector_stores_metadata" + self.openai_vector_stores: dict[str, dict[str, Any]] = {} + self.metadata_collection_name = "openai_vector_stores_metadata" async def initialize(self) -> None: log.info(f"Initializing PGVector memory adapter with config: {self.config}") @@ -272,7 +437,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco if vector_db_id in self.cache: return self.cache[vector_db_id] + if self.vector_db_store is None: + raise VectorStoreNotFoundError(vector_db_id) + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) + if not vector_db: + raise VectorStoreNotFoundError(vector_db_id) + index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn) 
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) return self.cache[vector_db_id] diff --git a/llama_stack/providers/utils/vector_io/vector_utils.py b/llama_stack/providers/utils/vector_io/vector_utils.py index f2888043e..4aeaab9c2 100644 --- a/llama_stack/providers/utils/vector_io/vector_utils.py +++ b/llama_stack/providers/utils/vector_io/vector_utils.py @@ -37,3 +37,122 @@ def sanitize_collection_name(name: str, weaviate_format=False) -> str: else: s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name)) return s + + +class Reranker: + @staticmethod + def _normalize_scores(scores: dict[str, float]) -> dict[str, float]: + """ + Normalize scores to 0-1 range using min-max normalization. + + Args: + scores: dictionary of scores with document IDs as keys and scores as values + + Returns: + Normalized scores with document IDs as keys and normalized scores as values + """ + if not scores: + return {} + min_score, max_score = min(scores.values()), max(scores.values()) + score_range = max_score - min_score + if score_range > 0: + return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()} + return dict.fromkeys(scores, 1.0) + + @staticmethod + def weighted_rerank( + vector_scores: dict[str, float], + keyword_scores: dict[str, float], + alpha: float = 0.5, + ) -> dict[str, float]: + """ + Rerank via weighted average of scores. 
+ + Args: + vector_scores: scores from vector search + keyword_scores: scores from keyword search + alpha: weight factor between 0 and 1 (default: 0.5) + 0 = keyword only, 1 = vector only, 0.5 = equal weight + + Returns: + All unique document IDs with weighted combined scores + """ + all_ids = set(vector_scores.keys()) | set(keyword_scores.keys()) + normalized_vector_scores = Reranker._normalize_scores(vector_scores) + normalized_keyword_scores = Reranker._normalize_scores(keyword_scores) + + # Weighted formula: score = (1-alpha) * keyword_score + alpha * vector_score + # alpha=0 means keyword only, alpha=1 means vector only + return { + doc_id: ((1 - alpha) * normalized_keyword_scores.get(doc_id, 0.0)) + + (alpha * normalized_vector_scores.get(doc_id, 0.0)) + for doc_id in all_ids + } + + @staticmethod + def rrf_rerank( + vector_scores: dict[str, float], + keyword_scores: dict[str, float], + impact_factor: float = 60.0, + ) -> dict[str, float]: + """ + Rerank via Reciprocal Rank Fusion. 
+ + Args: + vector_scores: scores from vector search + keyword_scores: scores from keyword search + impact_factor: impact factor for RRF (default: 60.0) + + Returns: + All unique document IDs with RRF combined scores + """ + + # Convert scores to ranks + vector_ranks = { + doc_id: i + 1 + for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True)) + } + keyword_ranks = { + doc_id: i + 1 + for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True)) + } + + all_ids = set(vector_scores.keys()) | set(keyword_scores.keys()) + rrf_scores = {} + for doc_id in all_ids: + vector_rank = vector_ranks.get(doc_id, float("inf")) + keyword_rank = keyword_ranks.get(doc_id, float("inf")) + + # RRF formula: score = 1/(k + r) where k is impact_factor (default: 60.0) and r is the rank + rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank)) + return rrf_scores + + @staticmethod + def combine_search_results( + vector_scores: dict[str, float], + keyword_scores: dict[str, float], + reranker_type: str = "rrf", + reranker_params: dict[str, float] | None = None, + ) -> dict[str, float]: + """ + Combine vector and keyword search results using specified reranking strategy. 
+ + Args: + vector_scores: scores from vector search + keyword_scores: scores from keyword search + reranker_type: type of reranker to use (default: RERANKER_TYPE_RRF) + reranker_params: parameters for the reranker + + Returns: + All unique document IDs with combined scores + """ + if reranker_params is None: + reranker_params = {} + + if reranker_type == "weighted": + alpha = reranker_params.get("alpha", 0.5) + return Reranker.weighted_rerank(vector_scores, keyword_scores, alpha) + else: + # Default to RRF for None, RRF, or any unknown types + impact_factor = reranker_params.get("impact_factor", 60.0) + return Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor) diff --git a/pyproject.toml b/pyproject.toml index db0ad1f00..e6c0595b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ unit = [ "openai", "aiosqlite", "aiohttp", + "psycopg2-binary>=2.9.0", "pypdf", "mcp", "chardet", @@ -109,6 +110,7 @@ test = [ "torch>=2.6.0", "torchvision>=0.21.0", "chardet", + "psycopg2-binary>=2.9.0", "pypdf", "mcp", "datasets", diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 7ccca9077..155d4ff42 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -56,11 +56,13 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode "keyword": [ "inline::sqlite-vec", "remote::milvus", + "remote::pgvector", ], "hybrid": [ "inline::sqlite-vec", "inline::milvus", "remote::milvus", + "remote::pgvector", ], } supported_providers = search_mode_support.get(search_mode, []) diff --git a/tests/unit/providers/utils/memory/test_reranking.py b/tests/unit/providers/utils/memory/test_reranking.py new file mode 100644 index 000000000..cb2a6383b --- /dev/null +++ b/tests/unit/providers/utils/memory/test_reranking.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from llama_stack.providers.utils.memory.vector_store import RERANKER_TYPE_RRF, RERANKER_TYPE_WEIGHTED +from llama_stack.providers.utils.vector_io.vector_utils import Reranker + + +class TestNormalizeScores: + """Test cases for score normalization.""" + + def test_normalize_scores_basic(self): + """Test basic score normalization.""" + scores = {"doc1": 10.0, "doc2": 5.0, "doc3": 0.0} + normalized = Reranker._normalize_scores(scores) + + assert normalized["doc1"] == 1.0 # Max score + assert normalized["doc3"] == 0.0 # Min score + assert normalized["doc2"] == 0.5 # Middle score + assert all(0 <= score <= 1 for score in normalized.values()) + + def test_normalize_scores_identical(self): + """Test normalization when all scores are identical.""" + scores = {"doc1": 5.0, "doc2": 5.0, "doc3": 5.0} + normalized = Reranker._normalize_scores(scores) + + # All scores should be 1.0 when identical + assert all(score == 1.0 for score in normalized.values()) + + def test_normalize_scores_empty(self): + """Test normalization with empty scores.""" + scores = {} + normalized = Reranker._normalize_scores(scores) + + assert normalized == {} + + def test_normalize_scores_single(self): + """Test normalization with single score.""" + scores = {"doc1": 7.5} + normalized = Reranker._normalize_scores(scores) + + assert normalized["doc1"] == 1.0 + + +class TestWeightedRerank: + """Test cases for weighted reranking.""" + + def test_weighted_rerank_basic(self): + """Test basic weighted reranking.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} + keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9} + + combined = Reranker.weighted_rerank(vector_scores, keyword_scores, alpha=0.5) + + # Should include all documents + expected_docs = {"doc1", "doc2", "doc3", "doc4"} + assert set(combined.keys()) == expected_docs + + # All scores 
should be between 0 and 1 + assert all(0 <= score <= 1 for score in combined.values()) + + # doc1 appears in both searches, should have higher combined score + assert combined["doc1"] > 0 + + def test_weighted_rerank_alpha_zero(self): + """Test weighted reranking with alpha=0 (keyword only).""" + vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} # All docs present in vector + keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9} # All docs present in keyword + + combined = Reranker.weighted_rerank(vector_scores, keyword_scores, alpha=0.0) + + # Alpha=0 means vector scores are ignored, keyword scores dominate + # doc3 should score highest since it has highest keyword score + assert combined["doc3"] > combined["doc2"] > combined["doc1"] + + def test_weighted_rerank_alpha_one(self): + """Test weighted reranking with alpha=1 (vector only).""" + vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} # All docs present in vector + keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9} # All docs present in keyword + + combined = Reranker.weighted_rerank(vector_scores, keyword_scores, alpha=1.0) + + # Alpha=1 means keyword scores are ignored, vector scores dominate + # doc1 should score highest since it has highest vector score + assert combined["doc1"] > combined["doc2"] > combined["doc3"] + + def test_weighted_rerank_no_overlap(self): + """Test weighted reranking with no overlapping documents.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc3": 0.8, "doc4": 0.6} + + combined = Reranker.weighted_rerank(vector_scores, keyword_scores, alpha=0.5) + + assert len(combined) == 4 + # With min-max normalization, lowest scoring docs in each group get 0.0 + # but highest scoring docs should get positive scores + assert all(score >= 0 for score in combined.values()) + assert combined["doc1"] > 0 # highest vector score + assert combined["doc3"] > 0 # highest keyword score + + +class TestRRFRerank: + """Test cases for RRF (Reciprocal Rank Fusion) 
reranking.""" + + def test_rrf_rerank_basic(self): + """Test basic RRF reranking.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} + keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9} + + combined = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0) + + # Should include all documents + expected_docs = {"doc1", "doc2", "doc3", "doc4"} + assert set(combined.keys()) == expected_docs + + # All scores should be positive + assert all(score > 0 for score in combined.values()) + + # Documents appearing in both searches should have higher scores + # doc1 and doc2 appear in both, doc3 and doc4 appear in only one + assert combined["doc1"] > combined["doc3"] + assert combined["doc2"] > combined["doc4"] + + def test_rrf_rerank_rank_calculation(self): + """Test that RRF correctly calculates ranks.""" + # Create clear ranking order + vector_scores = {"doc1": 1.0, "doc2": 0.8, "doc3": 0.6} # Ranks: 1, 2, 3 + keyword_scores = {"doc1": 0.5, "doc2": 1.0, "doc3": 0.7} # Ranks: 3, 1, 2 + + combined = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0) + + # doc1: rank 1 in vector, rank 3 in keyword + # doc2: rank 2 in vector, rank 1 in keyword + # doc3: rank 3 in vector, rank 2 in keyword + + # doc2 should have the highest combined score (ranks 2+1=3) + # followed by doc1 (ranks 1+3=4) and doc3 (ranks 3+2=5) + # Remember: lower rank sum = higher RRF score + assert combined["doc2"] > combined["doc1"] > combined["doc3"] + + def test_rrf_rerank_impact_factor(self): + """Test that impact factor affects RRF scores.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc1": 0.8, "doc2": 0.6} + + combined_low = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor=10.0) + combined_high = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor=100.0) + + # Higher impact factor should generally result in lower scores + # (because 1/(k+r) decreases as k increases) + assert combined_low["doc1"] > 
combined_high["doc1"] + assert combined_low["doc2"] > combined_high["doc2"] + + def test_rrf_rerank_missing_documents(self): + """Test RRF handling of documents missing from one search.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc1": 0.8, "doc3": 0.6} + + combined = Reranker.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0) + + # Should include all documents + assert len(combined) == 3 + + # doc1 appears in both searches, should have highest score + assert combined["doc1"] > combined["doc2"] + assert combined["doc1"] > combined["doc3"] + + +class TestCombineSearchResults: + """Test cases for the main combine_search_results function.""" + + def test_combine_search_results_rrf_default(self): + """Test combining with RRF as default.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc1": 0.6, "doc3": 0.8} + + combined = Reranker.combine_search_results(vector_scores, keyword_scores) + + # Should default to RRF + assert len(combined) == 3 + assert all(score > 0 for score in combined.values()) + + def test_combine_search_results_rrf_explicit(self): + """Test combining with explicit RRF.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc1": 0.6, "doc3": 0.8} + + combined = Reranker.combine_search_results( + vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_RRF, reranker_params={"impact_factor": 50.0} + ) + + assert len(combined) == 3 + assert all(score > 0 for score in combined.values()) + + def test_combine_search_results_weighted(self): + """Test combining with weighted reranking.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc1": 0.6, "doc3": 0.8} + + combined = Reranker.combine_search_results( + vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_WEIGHTED, reranker_params={"alpha": 0.3} + ) + + assert len(combined) == 3 + assert all(0 <= score <= 1 for score in combined.values()) + + def test_combine_search_results_unknown_type(self): + """Test 
combining with unknown reranker type defaults to RRF.""" + vector_scores = {"doc1": 0.9} + keyword_scores = {"doc2": 0.8} + + combined = Reranker.combine_search_results(vector_scores, keyword_scores, reranker_type="unknown_type") + + # Should fall back to RRF + assert len(combined) == 2 + assert all(score > 0 for score in combined.values()) + + def test_combine_search_results_empty_params(self): + """Test combining with empty parameters.""" + vector_scores = {"doc1": 0.9} + keyword_scores = {"doc2": 0.8} + + combined = Reranker.combine_search_results(vector_scores, keyword_scores, reranker_params={}) + + # Should use default parameters + assert len(combined) == 2 + assert all(score > 0 for score in combined.values()) + + def test_combine_search_results_empty_scores(self): + """Test combining with empty score dictionaries.""" + # Test with empty vector scores + combined = Reranker.combine_search_results({}, {"doc1": 0.8}) + assert len(combined) == 1 + assert combined["doc1"] > 0 + + # Test with empty keyword scores + combined = Reranker.combine_search_results({"doc1": 0.9}, {}) + assert len(combined) == 1 + assert combined["doc1"] > 0 + + # Test with both empty + combined = Reranker.combine_search_results({}, {}) + assert len(combined) == 0 diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index f71073651..91bddd037 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -5,6 +5,7 @@ # the root directory of this source tree. 
import random +from unittest.mock import AsyncMock, MagicMock, patch import numpy as np import pytest @@ -12,7 +13,7 @@ from chromadb import PersistentClient from pymilvus import MilvusClient, connections from llama_stack.apis.vector_dbs import VectorDB -from llama_stack.apis.vector_io import Chunk, ChunkMetadata +from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter @@ -22,6 +23,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter +from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig +from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter EMBEDDING_DIMENSION = 384 @@ -29,7 +32,7 @@ COLLECTION_PREFIX = "test_collection" MILVUS_ALIAS = "test_milvus" -@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"]) +@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"]) def vector_provider(request): return request.param @@ -333,15 +336,127 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension): await index.delete() +@pytest.fixture +def mock_psycopg2_connection(): + connection = MagicMock() + cursor = MagicMock() + + cursor.__enter__ = MagicMock(return_value=cursor) + cursor.__exit__ = MagicMock() + + 
connection.cursor.return_value = cursor + + return connection, cursor + + +@pytest.fixture +async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection): + connection, cursor = mock_psycopg2_connection + + vector_db = VectorDB( + identifier="test-vector-db", + embedding_model="test-model", + embedding_dimension=embedding_dimension, + provider_id="pgvector", + provider_resource_id="pgvector:test-vector-db", + ) + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"): + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"): + index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE") + index._test_chunks = [] + original_add_chunks = index.add_chunks + + async def mock_add_chunks(chunks, embeddings): + index._test_chunks = list(chunks) + await original_add_chunks(chunks, embeddings) + + index.add_chunks = mock_add_chunks + + async def mock_query_vector(embedding, k, score_threshold): + chunks = index._test_chunks[:k] if hasattr(index, "_test_chunks") else [] + scores = [1.0] * len(chunks) + return QueryChunksResponse(chunks=chunks, scores=scores) + + index.query_vector = mock_query_vector + + yield index + + +@pytest.fixture +async def pgvector_vec_adapter(mock_inference_api, embedding_dimension): + config = PGVectorVectorIOConfig( + host="localhost", + port=5432, + db="test_db", + user="test_user", + password="test_password", + kvstore=SqliteKVStoreConfig(), + ) + + adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None) + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2.connect") as mock_connect: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_conn.autocommit = True + mock_connect.return_value = mock_conn + + with patch( + 
"llama_stack.providers.remote.vector_io.pgvector.pgvector.check_extension_version" + ) as mock_check_version: + mock_check_version.return_value = "0.5.1" + + with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl: + mock_kvstore = AsyncMock() + mock_kvstore_impl.return_value = mock_kvstore + + with patch.object(adapter, "initialize_openai_vector_stores", new_callable=AsyncMock): + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.upsert_models"): + await adapter.initialize() + adapter.conn = mock_conn + + async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None): + index = await adapter._get_and_cache_vector_db_index(vector_db_id) + if not index: + raise ValueError(f"Vector DB {vector_db_id} not found") + await index.insert_chunks(chunks) + + adapter.insert_chunks = mock_insert_chunks + + async def mock_query_chunks(vector_db_id, query, params=None): + index = await adapter._get_and_cache_vector_db_index(vector_db_id) + if not index: + raise ValueError(f"Vector DB {vector_db_id} not found") + return await index.query_chunks(query, params) + + adapter.query_chunks = mock_query_chunks + + test_vector_db = VectorDB( + identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}", + provider_id="test_provider", + embedding_model="test_model", + embedding_dimension=embedding_dimension, + ) + await adapter.register_vector_db(test_vector_db) + adapter.test_collection_id = test_vector_db.identifier + + yield adapter + await adapter.shutdown() + + @pytest.fixture def vector_io_adapter(vector_provider, request): - """Returns the appropriate vector IO adapter based on the provider parameter.""" vector_provider_dict = { "milvus": "milvus_vec_adapter", "faiss": "faiss_vec_adapter", "sqlite_vec": "sqlite_vec_adapter", "chroma": "chroma_vec_adapter", "qdrant": "qdrant_vec_adapter", + "pgvector": "pgvector_vec_adapter", } return request.getfixturevalue(vector_provider_dict[vector_provider]) diff --git 
a/tests/unit/providers/vector_io/remote/test_pgvector.py b/tests/unit/providers/vector_io/remote/test_pgvector.py new file mode 100644 index 000000000..006074f98 --- /dev/null +++ b/tests/unit/providers/vector_io/remote/test_pgvector.py @@ -0,0 +1,502 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest + +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse +from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig +from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter +from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion + +PGVECTOR_PROVIDER = "pgvector" + + +@pytest.fixture(scope="session") +def loop(): + return asyncio.new_event_loop() + + +@pytest.fixture +def embedding_dimension(): + """Default embedding dimension for tests.""" + return 384 + + +@pytest.fixture +def mock_psycopg2_connection(): + """Create a mock psycopg2 connection for testing.""" + connection = MagicMock() + cursor = MagicMock() + + # Mock the cursor context manager + cursor.__enter__ = MagicMock(return_value=cursor) + cursor.__exit__ = MagicMock() + + # Mock connection cursor method + connection.cursor.return_value = cursor + + return connection, cursor + + +@pytest.fixture +async def pgvector_index(embedding_dimension, mock_psycopg2_connection): + """Create a PGVectorIndex instance with mocked database connection.""" + connection, cursor = mock_psycopg2_connection + + vector_db = VectorDB( + identifier="test-vector-db", + embedding_model="test-model", + embedding_dimension=embedding_dimension, + provider_id=PGVECTOR_PROVIDER, + 
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db", + ) + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"): + # Use explicit COSINE distance metric for consistent testing + index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE") + + return index, cursor + + +def create_sample_chunks(): + """Create sample chunks for testing.""" + return [ + Chunk( + content="Machine learning is a subset of artificial intelligence", + metadata={"document_id": "doc-1", "topic": "AI"}, + chunk_metadata=ChunkMetadata(document_id="doc-1", chunk_id="chunk-1"), + ), + Chunk( + content="Deep learning uses neural networks with multiple layers", + metadata={"document_id": "doc-2", "topic": "AI"}, + chunk_metadata=ChunkMetadata(document_id="doc-2", chunk_id="chunk-2"), + ), + Chunk( + content="Natural language processing enables computers to understand text", + metadata={"document_id": "doc-3", "topic": "NLP"}, + chunk_metadata=ChunkMetadata(document_id="doc-3", chunk_id="chunk-3"), + ), + ] + + +def create_sample_embeddings(num_chunks, dimension=384): + """Create sample embeddings for testing.""" + np.random.seed(42) + return np.array([np.random.rand(dimension).astype(np.float32) for _ in range(num_chunks)]) + + +class TestPGVectorIndex: + def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection): + connection, cursor = mock_psycopg2_connection + + vector_db = VectorDB( + identifier="test-vector-db", + embedding_model="test-model", + embedding_dimension=embedding_dimension, + provider_id=PGVECTOR_PROVIDER, + provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db", + ) + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"): + index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2") + assert index.distance_metric == "L2" + with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"): + 
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID") + + def test_get_pgvector_search_operator(self, pgvector_index): + index, cursor = pgvector_index + + assert index.get_pgvector_search_operator() == "<=>" + + index.distance_metric = "L2" + assert index.get_pgvector_search_operator() == "<->" + + index.distance_metric = "L1" + assert index.get_pgvector_search_operator() == "<+>" + + index.distance_metric = "INNER_PRODUCT" + assert index.get_pgvector_search_operator() == "<#>" + + index.distance_metric = "HAMMING" + assert index.get_pgvector_search_operator() == "<~>" + + index.distance_metric = "JACCARD" + assert index.get_pgvector_search_operator() == "<%>" + + def test_check_distance_metric_availability(self, pgvector_index): + index, cursor = pgvector_index + index.check_distance_metric_availability("COSINE") + index.check_distance_metric_availability("L2") + index.check_distance_metric_availability("L1") + index.check_distance_metric_availability("INNER_PRODUCT") + index.check_distance_metric_availability("HAMMING") + index.check_distance_metric_availability("JACCARD") + + with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"): + index.check_distance_metric_availability("INVALID") + + async def test_add_chunks(self, pgvector_index, embedding_dimension): + index, cursor = pgvector_index + chunks = create_sample_chunks() + embeddings = create_sample_embeddings(len(chunks), embedding_dimension) + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values") as mock_execute_values: + await index.add_chunks(chunks, embeddings) + + assert mock_execute_values.called + call_args = mock_execute_values.call_args + + query_arg = str(call_args[0][1]) + assert "INSERT INTO" in query_arg + assert "content_tsvector" in query_arg + + async def test_query_vector(self, pgvector_index, embedding_dimension): + index, cursor = pgvector_index + query_embedding = create_sample_embeddings(1, 
embedding_dimension)[0] + + mock_results = [ + ({"content": "test content", "metadata": {}}, 0.1), + ({"content": "test content 2", "metadata": {}}, 0.2), + ] + cursor.fetchall.return_value = mock_results + + response = await index.query_vector(query_embedding, k=2, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + + call_args = cursor.execute.call_args + assert "<=>" in str(call_args) or "ORDER BY" in str(call_args) + + async def test_query_vector_different_metrics(self, pgvector_index, embedding_dimension): + index, cursor = pgvector_index + query_embedding = create_sample_embeddings(1, embedding_dimension)[0] + + mock_results = [ + ({"content": "test content", "metadata": {}}, 0.1), + ] + cursor.fetchall.return_value = mock_results + + # Test L2 distance + index.distance_metric = "L2" + await index.query_vector(query_embedding, k=1, score_threshold=0.0) + call_args = cursor.execute.call_args + assert "<->" in str(call_args[0][0]) # L2 operator + + # Test L1 distance + index.distance_metric = "L1" + await index.query_vector(query_embedding, k=1, score_threshold=0.0) + call_args = cursor.execute.call_args + assert "<+>" in str(call_args[0][0]) # L1 operator + + # Test INNER_PRODUCT distance + index.distance_metric = "INNER_PRODUCT" + await index.query_vector(query_embedding, k=1, score_threshold=0.0) + call_args = cursor.execute.call_args + assert "<#>" in str(call_args[0][0]) # Inner product operator + + # Test Hamming distance + index.distance_metric = "HAMMING" + await index.query_vector(query_embedding, k=1, score_threshold=0.0) + call_args = cursor.execute.call_args + assert "<~>" in str(call_args[0][0]) # Hamming operator + + # Test Jaccard distance + index.distance_metric = "JACCARD" + await index.query_vector(query_embedding, k=1, score_threshold=0.0) + call_args = cursor.execute.call_args + assert "<%>" in str(call_args[0][0]) # Jaccard operator + + async def 
test_query_keyword(self, pgvector_index): + index, cursor = pgvector_index + query_string = "machine learning" + + mock_results = [ + ({"content": "Machine learning is great", "metadata": {}}, 0.8), + ({"content": "Learning machines are useful", "metadata": {}}, 0.6), + ] + cursor.fetchall.return_value = mock_results + + response = await index.query_keyword(query_string, k=2, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + + call_args = cursor.execute.call_args + assert "ts_rank" in str(call_args) or "plainto_tsquery" in str(call_args) + + async def test_query_keyword_with_score_threshold(self, pgvector_index): + index, cursor = pgvector_index + query_string = "machine learning" + score_threshold = 0.7 + + # Mock database response with mixed scores + mock_results = [ + ({"content": "Machine learning is great", "metadata": {}}, 0.8), # Above threshold + ({"content": "Learning machines are useful", "metadata": {}}, 0.5), # Below threshold + ] + cursor.fetchall.return_value = mock_results + + response = await index.query_keyword(query_string, k=2, score_threshold=score_threshold) + + assert len(response.chunks) == 1 + assert response.scores[0] >= score_threshold + + async def test_query_hybrid_rrf(self, pgvector_index, embedding_dimension): + index, cursor = pgvector_index + query_embedding = create_sample_embeddings(1, embedding_dimension)[0] + query_string = "machine learning" + + # Mock responses for both vector and keyword searches + vector_results = [ + ({"content": "Vector result 1", "metadata": {}, "chunk_id": "chunk-1"}, 0.9), + ({"content": "Vector result 2", "metadata": {}, "chunk_id": "chunk-2"}, 0.7), + ] + keyword_results = [ + ({"content": "Keyword result 1", "metadata": {}, "chunk_id": "chunk-3"}, 0.8), + ({"content": "Vector result 1", "metadata": {}, "chunk_id": "chunk-1"}, 0.6), # Overlap + ] + + cursor.fetchall.side_effect = [vector_results, 
keyword_results] + + response = await index.query_hybrid( + query_embedding, + query_string, + k=3, + score_threshold=0.0, + reranker_type="rrf", + reranker_params={"impact_factor": 60.0}, + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) >= 1 # At least the overlapping chunk + assert cursor.execute.call_count >= 2 + + async def test_query_hybrid_weighted(self, pgvector_index, embedding_dimension): + index, cursor = pgvector_index + query_embedding = create_sample_embeddings(1, embedding_dimension)[0] + query_string = "machine learning" + + vector_results = [ + ({"content": "Vector result", "metadata": {}, "chunk_id": "chunk-1"}, 0.9), + ] + keyword_results = [ + ({"content": "Keyword result", "metadata": {}, "chunk_id": "chunk-2"}, 0.8), + ] + + cursor.fetchall.side_effect = [vector_results, keyword_results] + + response = await index.query_hybrid( + query_embedding, + query_string, + k=2, + score_threshold=0.0, + reranker_type="weighted", + reranker_params={"alpha": 0.7}, + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + + def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection): + connection, cursor = mock_psycopg2_connection + + vector_db = VectorDB( + identifier="test-vector-db", + embedding_model="test-model", + embedding_dimension=embedding_dimension, + provider_id=PGVECTOR_PROVIDER, + provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db", + ) + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"): + with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"): + PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC") + + with pytest.raises(ValueError, match="Supported metrics are:"): + PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN") + + try: + index = PGVectorIndex(vector_db, 
embedding_dimension, connection, distance_metric="COSINE") + assert index.distance_metric == "COSINE" + except ValueError: + pytest.fail("Valid distance metric 'COSINE' should not raise ValueError") + + def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection): + connection, cursor = mock_psycopg2_connection + + vector_db = VectorDB( + identifier="test-vector-db", + embedding_model="test-model", + embedding_dimension=embedding_dimension, + provider_id=PGVECTOR_PROVIDER, + provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db", + ) + + supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"] + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"): + for metric in supported_metrics: + try: + index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric) + assert index.distance_metric == metric + + expected_operators = { + "L2": "<->", + "L1": "<+>", + "COSINE": "<=>", + "INNER_PRODUCT": "<#>", + "HAMMING": "<~>", + "JACCARD": "<%>", + } + assert index.get_pgvector_search_operator() == expected_operators[metric] + except Exception as e: + pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}") + + def test_error_handling_in_constructor(self, embedding_dimension): + vector_db = VectorDB( + identifier="test-vector-db", + embedding_model="test-model", + embedding_dimension=embedding_dimension, + provider_id=PGVECTOR_PROVIDER, + provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db", + ) + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"): + mock_connection = MagicMock() + + mock_cursor_context = MagicMock() + mock_cursor_context.__enter__.side_effect = Exception("Database connection failed") + mock_cursor_context.__exit__ = MagicMock() + + mock_connection.cursor.return_value = mock_cursor_context + + with pytest.raises(RuntimeError, match="Error creating PGVectorIndex"): + 
PGVectorIndex(vector_db, embedding_dimension, mock_connection, distance_metric="COSINE") + + async def test_delete_chunks(self, pgvector_index): + index, cursor = pgvector_index + + chunks_for_deletion = [ + ChunkForDeletion(chunk_id="test-chunk-1", document_id="doc-1"), + ChunkForDeletion(chunk_id="test-chunk-2", document_id="doc-2"), + ] + + await index.delete_chunks(chunks_for_deletion) + + cursor.execute.assert_called() + call_args = cursor.execute.call_args + assert "DELETE FROM" in str(call_args) + assert "test-chunk-1" in str(call_args) or "test-chunk-2" in str(call_args) + + async def test_delete_index(self, pgvector_index): + """Test deleting the entire index.""" + index, cursor = pgvector_index + + await index.delete() + + cursor.execute.assert_called_with(f"DROP TABLE IF EXISTS {index.table_name}") + + +class TestPGVectorVectorIOAdapter: + @pytest.fixture + async def pgvector_adapter(self): + config = PGVectorVectorIOConfig( + host="localhost", + port=5432, + db="test_db", + user="test_user", + password="test_password", + kvstore={"type": "sqlite", "config": {"db_path": ":memory:"}}, + ) + + inference_api = AsyncMock() + files_api = AsyncMock() + + adapter = PGVectorVectorIOAdapter(config, inference_api, files_api) + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2.connect") as mock_connect: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + with patch.object(adapter, "kvstore") as mock_kvstore: + mock_kvstore.set = AsyncMock() + mock_kvstore.get = AsyncMock() + + yield adapter, mock_conn, mock_cursor + + async def test_initialization(self, pgvector_adapter): + adapter, mock_conn, mock_cursor = pgvector_adapter + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.kvstore_impl") as mock_kvstore_impl: + mock_kvstore = 
AsyncMock() + mock_kvstore_impl.return_value = mock_kvstore + + mock_cursor.fetchone.return_value = ["0.5.0"] + + await adapter.initialize() + + assert adapter.conn is not None + mock_cursor.execute.assert_called() + + assert adapter.metadata_collection_name == "openai_vector_stores_metadata" + + async def test_register_vector_db(self, pgvector_adapter): + adapter, mock_conn, mock_cursor = pgvector_adapter + + vector_db = VectorDB( + identifier="test-db", + embedding_model="test-model", + embedding_dimension=384, + provider_id=PGVECTOR_PROVIDER, + provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db", + ) + + adapter.kvstore = AsyncMock() + adapter.conn = mock_conn + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.upsert_models"): + await adapter.register_vector_db(vector_db) + + assert "test-db" in adapter.cache + + async def test_insert_chunks(self, pgvector_adapter): + adapter, mock_conn, mock_cursor = pgvector_adapter + + adapter.conn = mock_conn + chunks = create_sample_chunks() + + mock_index = AsyncMock() + adapter.cache["test-db"] = mock_index + + await adapter.insert_chunks("test-db", chunks) + mock_index.insert_chunks.assert_called_once_with(chunks) + + async def test_delete_chunks(self, pgvector_adapter): + adapter, mock_conn, mock_cursor = pgvector_adapter + adapter.conn = mock_conn + mock_index = AsyncMock() + mock_index.index = AsyncMock() + adapter.cache["test-db"] = mock_index + + chunks_for_deletion = [ + ChunkForDeletion(chunk_id="chunk-1", document_id="doc-1"), + ChunkForDeletion(chunk_id="chunk-2", document_id="doc-2"), + ] + await adapter.delete_chunks("test-db", chunks_for_deletion) + + mock_index.index.delete_chunks.assert_called_once_with(chunks_for_deletion) diff --git a/uv.lock b/uv.lock index 4c56816ef..db056f423 100644 --- a/uv.lock +++ b/uv.lock @@ -1810,6 +1810,7 @@ test = [ { name = "datasets" }, { name = "mcp" }, { name = "openai" }, + { name = "psycopg2-binary" }, { name = "pymilvus" }, { name = "pypdf" 
}, { name = "requests" }, @@ -1833,6 +1834,7 @@ unit = [ { name = "mcp" }, { name = "ollama" }, { name = "openai" }, + { name = "psycopg2-binary" }, { name = "pymilvus" }, { name = "pypdf" }, { name = "qdrant-client" }, @@ -1926,6 +1928,7 @@ test = [ { name = "datasets" }, { name = "mcp" }, { name = "openai" }, + { name = "psycopg2-binary", specifier = ">=2.9.0" }, { name = "pymilvus", specifier = ">=2.5.12" }, { name = "pypdf" }, { name = "requests" }, @@ -1948,6 +1951,7 @@ unit = [ { name = "mcp" }, { name = "ollama" }, { name = "openai" }, + { name = "psycopg2-binary", specifier = ">=2.9.0" }, { name = "pymilvus", specifier = ">=2.5.12" }, { name = "pypdf" }, { name = "qdrant-client" }, @@ -3058,6 +3062,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885, upload-time = "2025-02-13T21:54:37.486Z" }, ] +[[package]] +name = "psycopg2-binary" +version = "2.9.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cb/0e/bdc8274dc0585090b4e3432267d7be4dfbfd8971c0fa59167c711105a6bf/psycopg2-binary-2.9.10.tar.gz", hash = "sha256:4b3df0e6990aa98acda57d983942eff13d824135fe2250e6522edaa782a06de2", size = 385764, upload-time = "2024-10-16T11:24:58.126Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/7d/465cc9795cf76f6d329efdafca74693714556ea3891813701ac1fee87545/psycopg2_binary-2.9.10-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:880845dfe1f85d9d5f7c412efea7a08946a46894537e4e5d091732eb1d34d9a0", size = 3044771, upload-time = "2024-10-16T11:20:35.234Z" }, + { url = "https://files.pythonhosted.org/packages/8b/31/6d225b7b641a1a2148e3ed65e1aa74fc86ba3fee850545e27be9e1de893d/psycopg2_binary-2.9.10-cp312-cp312-macosx_14_0_arm64.whl", hash = 
"sha256:9440fa522a79356aaa482aa4ba500b65f28e5d0e63b801abf6aa152a29bd842a", size = 3275336, upload-time = "2024-10-16T11:20:38.742Z" }, + { url = "https://files.pythonhosted.org/packages/30/b7/a68c2b4bff1cbb1728e3ec864b2d92327c77ad52edcd27922535a8366f68/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3923c1d9870c49a2d44f795df0c889a22380d36ef92440ff618ec315757e539", size = 2851637, upload-time = "2024-10-16T11:20:42.145Z" }, + { url = "https://files.pythonhosted.org/packages/0b/b1/cfedc0e0e6f9ad61f8657fd173b2f831ce261c02a08c0b09c652b127d813/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b2c956c028ea5de47ff3a8d6b3cc3330ab45cf0b7c3da35a2d6ff8420896526", size = 3082097, upload-time = "2024-10-16T11:20:46.185Z" }, + { url = "https://files.pythonhosted.org/packages/18/ed/0a8e4153c9b769f59c02fb5e7914f20f0b2483a19dae7bf2db54b743d0d0/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f758ed67cab30b9a8d2833609513ce4d3bd027641673d4ebc9c067e4d208eec1", size = 3264776, upload-time = "2024-10-16T11:20:50.879Z" }, + { url = "https://files.pythonhosted.org/packages/10/db/d09da68c6a0cdab41566b74e0a6068a425f077169bed0946559b7348ebe9/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cd9b4f2cfab88ed4a9106192de509464b75a906462fb846b936eabe45c2063e", size = 3020968, upload-time = "2024-10-16T11:20:56.819Z" }, + { url = "https://files.pythonhosted.org/packages/94/28/4d6f8c255f0dfffb410db2b3f9ac5218d959a66c715c34cac31081e19b95/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dc08420625b5a20b53551c50deae6e231e6371194fa0651dbe0fb206452ae1f", size = 2872334, upload-time = "2024-10-16T11:21:02.411Z" }, + { url = 
"https://files.pythonhosted.org/packages/05/f7/20d7bf796593c4fea95e12119d6cc384ff1f6141a24fbb7df5a668d29d29/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d7cd730dfa7c36dbe8724426bf5612798734bff2d3c3857f36f2733f5bfc7c00", size = 2822722, upload-time = "2024-10-16T11:21:09.01Z" }, + { url = "https://files.pythonhosted.org/packages/4d/e4/0c407ae919ef626dbdb32835a03b6737013c3cc7240169843965cada2bdf/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:155e69561d54d02b3c3209545fb08938e27889ff5a10c19de8d23eb5a41be8a5", size = 2920132, upload-time = "2024-10-16T11:21:16.339Z" }, + { url = "https://files.pythonhosted.org/packages/2d/70/aa69c9f69cf09a01da224909ff6ce8b68faeef476f00f7ec377e8f03be70/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3cc28a6fd5a4a26224007712e79b81dbaee2ffb90ff406256158ec4d7b52b47", size = 2959312, upload-time = "2024-10-16T11:21:25.584Z" }, + { url = "https://files.pythonhosted.org/packages/d3/bd/213e59854fafe87ba47814bf413ace0dcee33a89c8c8c814faca6bc7cf3c/psycopg2_binary-2.9.10-cp312-cp312-win32.whl", hash = "sha256:ec8a77f521a17506a24a5f626cb2aee7850f9b69a0afe704586f63a464f3cd64", size = 1025191, upload-time = "2024-10-16T11:21:29.912Z" }, + { url = "https://files.pythonhosted.org/packages/92/29/06261ea000e2dc1e22907dbbc483a1093665509ea586b29b8986a0e56733/psycopg2_binary-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:18c5ee682b9c6dd3696dad6e54cc7ff3a1a9020df6a5c0f861ef8bfd338c3ca0", size = 1164031, upload-time = "2024-10-16T11:21:34.211Z" }, + { url = "https://files.pythonhosted.org/packages/3e/30/d41d3ba765609c0763505d565c4d12d8f3c79793f0d0f044ff5a28bf395b/psycopg2_binary-2.9.10-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:26540d4a9a4e2b096f1ff9cce51253d0504dca5a85872c7f7be23be5a53eb18d", size = 3044699, upload-time = "2024-10-16T11:21:42.841Z" }, + { url = 
"https://files.pythonhosted.org/packages/35/44/257ddadec7ef04536ba71af6bc6a75ec05c5343004a7ec93006bee66c0bc/psycopg2_binary-2.9.10-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e217ce4d37667df0bc1c397fdcd8de5e81018ef305aed9415c3b093faaeb10fb", size = 3275245, upload-time = "2024-10-16T11:21:51.989Z" }, + { url = "https://files.pythonhosted.org/packages/1b/11/48ea1cd11de67f9efd7262085588790a95d9dfcd9b8a687d46caf7305c1a/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:245159e7ab20a71d989da00f280ca57da7641fa2cdcf71749c193cea540a74f7", size = 2851631, upload-time = "2024-10-16T11:21:57.584Z" }, + { url = "https://files.pythonhosted.org/packages/62/e0/62ce5ee650e6c86719d621a761fe4bc846ab9eff8c1f12b1ed5741bf1c9b/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c4ded1a24b20021ebe677b7b08ad10bf09aac197d6943bfe6fec70ac4e4690d", size = 3082140, upload-time = "2024-10-16T11:22:02.005Z" }, + { url = "https://files.pythonhosted.org/packages/27/ce/63f946c098611f7be234c0dd7cb1ad68b0b5744d34f68062bb3c5aa510c8/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3abb691ff9e57d4a93355f60d4f4c1dd2d68326c968e7db17ea96df3c023ef73", size = 3264762, upload-time = "2024-10-16T11:22:06.412Z" }, + { url = "https://files.pythonhosted.org/packages/43/25/c603cd81402e69edf7daa59b1602bd41eb9859e2824b8c0855d748366ac9/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8608c078134f0b3cbd9f89b34bd60a943b23fd33cc5f065e8d5f840061bd0673", size = 3020967, upload-time = "2024-10-16T11:22:11.583Z" }, + { url = "https://files.pythonhosted.org/packages/5f/d6/8708d8c6fca531057fa170cdde8df870e8b6a9b136e82b361c65e42b841e/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:230eeae2d71594103cd5b93fd29d1ace6420d0b86f4778739cb1a5a32f607d1f", size = 2872326, upload-time = 
"2024-10-16T11:22:16.406Z" }, + { url = "https://files.pythonhosted.org/packages/ce/ac/5b1ea50fc08a9df82de7e1771537557f07c2632231bbab652c7e22597908/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909", size = 2822712, upload-time = "2024-10-16T11:22:21.366Z" }, + { url = "https://files.pythonhosted.org/packages/c4/fc/504d4503b2abc4570fac3ca56eb8fed5e437bf9c9ef13f36b6621db8ef00/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1", size = 2920155, upload-time = "2024-10-16T11:22:25.684Z" }, + { url = "https://files.pythonhosted.org/packages/b2/d1/323581e9273ad2c0dbd1902f3fb50c441da86e894b6e25a73c3fda32c57e/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567", size = 2959356, upload-time = "2024-10-16T11:22:30.562Z" }, + { url = "https://files.pythonhosted.org/packages/08/50/d13ea0a054189ae1bc21af1d85b6f8bb9bbc5572991055d70ad9006fe2d6/psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142", size = 2569224, upload-time = "2025-01-04T20:09:19.234Z" }, +] + [[package]] name = "ptyprocess" version = "0.7.0"