feat: implement keyword, vector and hybrid search inside vector stores for PGVector provider (#3064)

# What does this PR do?
This PR implements the `openai/v1/vector_stores/{vector_store_id}/search` endpoint for the PGVector provider. It adds vector similarity search, keyword search, and hybrid search to `PGVectorIndex`.

Closes #3006 
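
As a quick illustration of what this enables, here is a minimal sketch that exercises the new endpoint in each search mode. The base URL, vector store id, and request fields (`search_mode`, `max_num_results`) are assumptions for illustration only, not something prescribed by this PR:

```
import requests

BASE_URL = "http://localhost:8321"  # placeholder Llama Stack server URL
VECTOR_STORE_ID = "vs_123"  # placeholder vector store id

# search_mode is assumed to accept "vector", "keyword" and "hybrid",
# matching the three modes implemented for PGVector in this PR.
for mode in ("vector", "keyword", "hybrid"):
    resp = requests.post(
        f"{BASE_URL}/openai/v1/vector_stores/{VECTOR_STORE_ID}/search",
        json={"query": "What is PGVector?", "search_mode": mode, "max_num_results": 5},
        timeout=30,
    )
    resp.raise_for_status()
    # Response shape assumed to follow the OpenAI vector store search format.
    print(mode, [hit.get("score") for hit in resp.json().get("data", [])])
```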

## Test Plan
Run unit tests:
`./scripts/unit-tests.sh`

Run integration tests for OpenAI vector stores:
1. Export env vars:
```
export ENABLE_PGVECTOR=true
export PGVECTOR_HOST=localhost
export PGVECTOR_PORT=5432
export PGVECTOR_DB=llamastack
export PGVECTOR_USER=llamastack
export PGVECTOR_PASSWORD=llamastack
```

2. Create DB:
```
psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';"
psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;"
psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;"
```

3. Install sentence-transformers:
`uv pip install sentence-transformers`

4. Run:
```
uv run --group test pytest -s -v --stack-config="inference=inline::sentence-transformers,vector_io=remote::pgvector" --embedding-model sentence-transformers/all-MiniLM-L6-v2 tests/integration/vector_io/test_openai_vector_stores.py
```
Inspect PGVector vector stores (optional):
```
psql llamastack                                                                                                         
psql (14.18 (Homebrew))
Type "help" for help.

llamastack=# \z
                                                    Access privileges
 Schema |                         Name                         | Type  | Access privileges | Column privileges | Policies 
--------+------------------------------------------------------+-------+-------------------+-------------------+----------
 public | llamastack_kvstore                                   | table |                   |                   | 
 public | metadata_store                                       | table |                   |                   | 
 public | vector_store_pgvector_main                           | table |                   |                   | 
 public | vector_store_vs_1dfbc061_1f4d_4497_9165_ecba2622ba3a | table |                   |                   | 
 public | vector_store_vs_2085a9fb_1822_4e42_a277_c6a685843fa7 | table |                   |                   | 
 public | vector_store_vs_2b3dae46_38be_462a_afd6_37ee5fe661b1 | table |                   |                   | 
 public | vector_store_vs_2f438de6_f606_4561_9d50_ef9160eb9060 | table |                   |                   | 
 public | vector_store_vs_3eeca564_2580_4c68_bfea_83dc57e31214 | table |                   |                   | 
 public | vector_store_vs_53942163_05f3_40e0_83c0_0997c64613da | table |                   |                   | 
 public | vector_store_vs_545bac75_8950_4ff1_b084_e221192d4709 | table |                   |                   | 
 public | vector_store_vs_688a37d8_35b2_4298_a035_bfedf5b21f86 | table |                   |                   | 
 public | vector_store_vs_70624d9a_f6ac_4c42_b8ab_0649473c6600 | table |                   |                   | 
 public | vector_store_vs_73fc1dd2_e942_4972_afb1_1e177b591ac2 | table |                   |                   | 
 public | vector_store_vs_9d464949_d51f_49db_9f87_e033b8b84ac9 | table |                   |                   | 
 public | vector_store_vs_a1e4d724_5162_4d6d_a6c0_bdafaf6b76ec | table |                   |                   | 
 public | vector_store_vs_a328fb1b_1a21_480f_9624_ffaa60fb6672 | table |                   |                   | 
 public | vector_store_vs_a8981bf0_2e66_4445_a267_a8fff442db53 | table |                   |                   | 
 public | vector_store_vs_ccd4b6a4_1efd_4984_ad03_e7ff8eadb296 | table |                   |                   | 
 public | vector_store_vs_cd6420a4_a1fc_4cec_948c_1413a26281c9 | table |                   |                   | 
 public | vector_store_vs_cd709284_e5cf_4a88_aba5_dc76a35364bd | table |                   |                   | 
 public | vector_store_vs_d7a4548e_fbc1_44d7_b2ec_b664417f2a46 | table |                   |                   | 
 public | vector_store_vs_e7f73231_414c_4523_886c_d1174eee836e | table |                   |                   | 
 public | vector_store_vs_ffd53588_819f_47e8_bb9d_954af6f7833d | table |                   |                   | 
(23 rows)

llamastack=# 
```

Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>
11 changed files with 1014 additions and 29 deletions; excerpts from the test changes follow.

```
@@ -57,11 +57,13 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
             "inline::sqlite-vec",
             "remote::milvus",
             "inline::milvus",
+            "remote::pgvector",
         ],
         "hybrid": [
             "inline::sqlite-vec",
             "inline::milvus",
             "remote::milvus",
+            "remote::pgvector",
         ],
     }
     supported_providers = search_mode_support.get(search_mode, [])
```


```
@@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.utils.memory.vector_store import RERANKER_TYPE_RRF, RERANKER_TYPE_WEIGHTED
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator


class TestNormalizeScores:
    """Test cases for score normalization."""

    def test_normalize_scores_basic(self):
        """Test basic score normalization."""
        scores = {"doc1": 10.0, "doc2": 5.0, "doc3": 0.0}
        normalized = WeightedInMemoryAggregator._normalize_scores(scores)

        assert normalized["doc1"] == 1.0  # Max score
        assert normalized["doc3"] == 0.0  # Min score
        assert normalized["doc2"] == 0.5  # Middle score
        assert all(0 <= score <= 1 for score in normalized.values())

    def test_normalize_scores_identical(self):
        """Test normalization when all scores are identical."""
        scores = {"doc1": 5.0, "doc2": 5.0, "doc3": 5.0}
        normalized = WeightedInMemoryAggregator._normalize_scores(scores)

        # All scores should be 1.0 when identical
        assert all(score == 1.0 for score in normalized.values())

    def test_normalize_scores_empty(self):
        """Test normalization with empty scores."""
        scores = {}
        normalized = WeightedInMemoryAggregator._normalize_scores(scores)

        assert normalized == {}

    def test_normalize_scores_single(self):
        """Test normalization with single score."""
        scores = {"doc1": 7.5}
        normalized = WeightedInMemoryAggregator._normalize_scores(scores)

        assert normalized["doc1"] == 1.0


class TestWeightedRerank:
    """Test cases for weighted reranking."""

    def test_weighted_rerank_basic(self):
        """Test basic weighted reranking."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}
        keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9}

        combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.5)

        # Should include all documents
        expected_docs = {"doc1", "doc2", "doc3", "doc4"}
        assert set(combined.keys()) == expected_docs

        # All scores should be between 0 and 1
        assert all(0 <= score <= 1 for score in combined.values())

        # doc1 appears in both searches, should have higher combined score
        assert combined["doc1"] > 0

    def test_weighted_rerank_alpha_zero(self):
        """Test weighted reranking with alpha=0 (keyword only)."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}  # All docs present in vector
        keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9}  # All docs present in keyword

        combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.0)

        # Alpha=0 means vector scores are ignored, keyword scores dominate
        # doc3 should score highest since it has highest keyword score
        assert combined["doc3"] > combined["doc2"] > combined["doc1"]

    def test_weighted_rerank_alpha_one(self):
        """Test weighted reranking with alpha=1 (vector only)."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}  # All docs present in vector
        keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9}  # All docs present in keyword

        combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=1.0)

        # Alpha=1 means keyword scores are ignored, vector scores dominate
        # doc1 should score highest since it has highest vector score
        assert combined["doc1"] > combined["doc2"] > combined["doc3"]

    def test_weighted_rerank_no_overlap(self):
        """Test weighted reranking with no overlapping documents."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc3": 0.8, "doc4": 0.6}

        combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.5)

        assert len(combined) == 4
        # With min-max normalization, lowest scoring docs in each group get 0.0
        # but highest scoring docs should get positive scores
        assert all(score >= 0 for score in combined.values())
        assert combined["doc1"] > 0  # highest vector score
        assert combined["doc3"] > 0  # highest keyword score


class TestRRFRerank:
    """Test cases for RRF (Reciprocal Rank Fusion) reranking."""

    def test_rrf_rerank_basic(self):
        """Test basic RRF reranking."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}
        keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9}

        combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)

        # Should include all documents
        expected_docs = {"doc1", "doc2", "doc3", "doc4"}
        assert set(combined.keys()) == expected_docs

        # All scores should be positive
        assert all(score > 0 for score in combined.values())

        # Documents appearing in both searches should have higher scores
        # doc1 and doc2 appear in both, doc3 and doc4 appear in only one
        assert combined["doc1"] > combined["doc3"]
        assert combined["doc2"] > combined["doc4"]

    def test_rrf_rerank_rank_calculation(self):
        """Test that RRF correctly calculates ranks."""
        # Create clear ranking order
        vector_scores = {"doc1": 1.0, "doc2": 0.8, "doc3": 0.6}  # Ranks: 1, 2, 3
        keyword_scores = {"doc1": 0.5, "doc2": 1.0, "doc3": 0.7}  # Ranks: 3, 1, 2

        combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)

        # doc1: rank 1 in vector, rank 3 in keyword
        # doc2: rank 2 in vector, rank 1 in keyword
        # doc3: rank 3 in vector, rank 2 in keyword

        # doc2 should have the highest combined score (ranks 2+1=3)
        # followed by doc1 (ranks 1+3=4) and doc3 (ranks 3+2=5)
        # Remember: lower rank sum = higher RRF score
        assert combined["doc2"] > combined["doc1"] > combined["doc3"]

    def test_rrf_rerank_impact_factor(self):
        """Test that impact factor affects RRF scores."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc1": 0.8, "doc2": 0.6}

        combined_low = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=10.0)
        combined_high = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=100.0)

        # Higher impact factor should generally result in lower scores
        # (because 1/(k+r) decreases as k increases)
        assert combined_low["doc1"] > combined_high["doc1"]
        assert combined_low["doc2"] > combined_high["doc2"]

    def test_rrf_rerank_missing_documents(self):
        """Test RRF handling of documents missing from one search."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc1": 0.8, "doc3": 0.6}

        combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)

        # Should include all documents
        assert len(combined) == 3

        # doc1 appears in both searches, should have highest score
        assert combined["doc1"] > combined["doc2"]
        assert combined["doc1"] > combined["doc3"]


class TestCombineSearchResults:
    """Test cases for the main combine_search_results function."""

    def test_combine_search_results_rrf_default(self):
        """Test combining with RRF as default."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc1": 0.6, "doc3": 0.8}

        combined = WeightedInMemoryAggregator.combine_search_results(vector_scores, keyword_scores)

        # Should default to RRF
        assert len(combined) == 3
        assert all(score > 0 for score in combined.values())

    def test_combine_search_results_rrf_explicit(self):
        """Test combining with explicit RRF."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc1": 0.6, "doc3": 0.8}

        combined = WeightedInMemoryAggregator.combine_search_results(
            vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_RRF, reranker_params={"impact_factor": 50.0}
        )

        assert len(combined) == 3
        assert all(score > 0 for score in combined.values())

    def test_combine_search_results_weighted(self):
        """Test combining with weighted reranking."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc1": 0.6, "doc3": 0.8}

        combined = WeightedInMemoryAggregator.combine_search_results(
            vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_WEIGHTED, reranker_params={"alpha": 0.3}
        )

        assert len(combined) == 3
        assert all(0 <= score <= 1 for score in combined.values())

    def test_combine_search_results_unknown_type(self):
        """Test combining with unknown reranker type defaults to RRF."""
        vector_scores = {"doc1": 0.9}
        keyword_scores = {"doc2": 0.8}

        combined = WeightedInMemoryAggregator.combine_search_results(
            vector_scores, keyword_scores, reranker_type="unknown_type"
        )

        # Should fall back to RRF
        assert len(combined) == 2
        assert all(score > 0 for score in combined.values())

    def test_combine_search_results_empty_params(self):
        """Test combining with empty parameters."""
        vector_scores = {"doc1": 0.9}
        keyword_scores = {"doc2": 0.8}

        combined = WeightedInMemoryAggregator.combine_search_results(vector_scores, keyword_scores, reranker_params={})

        # Should use default parameters
        assert len(combined) == 2
        assert all(score > 0 for score in combined.values())

    def test_combine_search_results_empty_scores(self):
        """Test combining with empty score dictionaries."""
        # Test with empty vector scores
        combined = WeightedInMemoryAggregator.combine_search_results({}, {"doc1": 0.8})
        assert len(combined) == 1
        assert combined["doc1"] > 0

        # Test with empty keyword scores
        combined = WeightedInMemoryAggregator.combine_search_results({"doc1": 0.9}, {})
        assert len(combined) == 1
        assert combined["doc1"] > 0

        # Test with both empty
        combined = WeightedInMemoryAggregator.combine_search_results({}, {})
        assert len(combined) == 0
```

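For reference, the combination logic these tests exercise can be sketched as below. This is an illustrative re-implementation of RRF and weighted score fusion, not the actual `WeightedInMemoryAggregator` code:

```
def rrf_rerank(vector_scores: dict[str, float], keyword_scores: dict[str, float], impact_factor: float = 60.0) -> dict[str, float]:
    """Reciprocal Rank Fusion: score(doc) = sum over result lists of 1 / (impact_factor + rank)."""
    combined: dict[str, float] = {}
    for scores in (vector_scores, keyword_scores):
        ranked = sorted(scores, key=scores.get, reverse=True)
        for rank, doc_id in enumerate(ranked, start=1):
            combined[doc_id] = combined.get(doc_id, 0.0) + 1.0 / (impact_factor + rank)
    return combined


def weighted_rerank(vector_scores: dict[str, float], keyword_scores: dict[str, float], alpha: float = 0.5) -> dict[str, float]:
    """Weighted sum of min-max normalized scores: alpha * vector + (1 - alpha) * keyword."""

    def normalize(scores: dict[str, float]) -> dict[str, float]:
        if not scores:
            return {}
        lo, hi = min(scores.values()), max(scores.values())
        # Identical scores all normalize to 1.0, as the tests above expect.
        return {d: 1.0 if hi == lo else (s - lo) / (hi - lo) for d, s in scores.items()}

    nv, nk = normalize(vector_scores), normalize(keyword_scores)
    return {d: alpha * nv.get(d, 0.0) + (1 - alpha) * nk.get(d, 0.0) for d in set(nv) | set(nk)}
```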

```
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 import random
 from unittest.mock import AsyncMock, MagicMock, patch

 import numpy as np
 import pytest

@@ -12,7 +13,7 @@ from chromadb import PersistentClient
 from pymilvus import MilvusClient, connections

 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import Chunk, ChunkMetadata
+from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
 from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter

@@ -22,6 +23,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf
 from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
 from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
 from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
+from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
+from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
 from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter

 EMBEDDING_DIMENSION = 384

@@ -29,7 +32,7 @@ COLLECTION_PREFIX = "test_collection"
 MILVUS_ALIAS = "test_milvus"


-@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
+@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
 def vector_provider(request):
     return request.param

@@ -333,15 +336,127 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
     await index.delete()


+@pytest.fixture
+def mock_psycopg2_connection():
+    connection = MagicMock()
+    cursor = MagicMock()
+
+    cursor.__enter__ = MagicMock(return_value=cursor)
+    cursor.__exit__ = MagicMock()
+    connection.cursor.return_value = cursor
+
+    return connection, cursor
+
+
+@pytest.fixture
+async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
+    connection, cursor = mock_psycopg2_connection
+
+    vector_db = VectorDB(
+        identifier="test-vector-db",
+        embedding_model="test-model",
+        embedding_dimension=embedding_dimension,
+        provider_id="pgvector",
+        provider_resource_id="pgvector:test-vector-db",
+    )
+
+    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
+        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
+            index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
+            index._test_chunks = []
+            original_add_chunks = index.add_chunks
+
+            async def mock_add_chunks(chunks, embeddings):
+                index._test_chunks = list(chunks)
+                await original_add_chunks(chunks, embeddings)
+
+            index.add_chunks = mock_add_chunks
+
+            async def mock_query_vector(embedding, k, score_threshold):
+                chunks = index._test_chunks[:k] if hasattr(index, "_test_chunks") else []
+                scores = [1.0] * len(chunks)
+                return QueryChunksResponse(chunks=chunks, scores=scores)
+
+            index.query_vector = mock_query_vector
+
+    yield index
+
+
+@pytest.fixture
+async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
+    config = PGVectorVectorIOConfig(
+        host="localhost",
+        port=5432,
+        db="test_db",
+        user="test_user",
+        password="test_password",
+        kvstore=SqliteKVStoreConfig(),
+    )
+
+    adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
+
+    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2.connect") as mock_connect:
+        mock_conn = MagicMock()
+        mock_cursor = MagicMock()
+        mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
+        mock_cursor.__exit__ = MagicMock()
+        mock_conn.cursor.return_value = mock_cursor
+        mock_conn.autocommit = True
+        mock_connect.return_value = mock_conn
+
+        with patch(
+            "llama_stack.providers.remote.vector_io.pgvector.pgvector.check_extension_version"
+        ) as mock_check_version:
+            mock_check_version.return_value = "0.5.1"
+
+            with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl:
+                mock_kvstore = AsyncMock()
+                mock_kvstore_impl.return_value = mock_kvstore
+
+                with patch.object(adapter, "initialize_openai_vector_stores", new_callable=AsyncMock):
+                    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.upsert_models"):
+                        await adapter.initialize()
+                        adapter.conn = mock_conn
+
+                        async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
+                            index = await adapter._get_and_cache_vector_db_index(vector_db_id)
+                            if not index:
+                                raise ValueError(f"Vector DB {vector_db_id} not found")
+                            await index.insert_chunks(chunks)
+
+                        adapter.insert_chunks = mock_insert_chunks
+
+                        async def mock_query_chunks(vector_db_id, query, params=None):
+                            index = await adapter._get_and_cache_vector_db_index(vector_db_id)
+                            if not index:
+                                raise ValueError(f"Vector DB {vector_db_id} not found")
+                            return await index.query_chunks(query, params)
+
+                        adapter.query_chunks = mock_query_chunks
+
+                        test_vector_db = VectorDB(
+                            identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
+                            provider_id="test_provider",
+                            embedding_model="test_model",
+                            embedding_dimension=embedding_dimension,
+                        )
+                        await adapter.register_vector_db(test_vector_db)
+                        adapter.test_collection_id = test_vector_db.identifier
+
+                        yield adapter
+                        await adapter.shutdown()
+
+
 @pytest.fixture
 def vector_io_adapter(vector_provider, request):
     """Returns the appropriate vector IO adapter based on the provider parameter."""
     vector_provider_dict = {
         "milvus": "milvus_vec_adapter",
         "faiss": "faiss_vec_adapter",
         "sqlite_vec": "sqlite_vec_adapter",
         "chroma": "chroma_vec_adapter",
         "qdrant": "qdrant_vec_adapter",
+        "pgvector": "pgvector_vec_adapter",
     }
     return request.getfixturevalue(vector_provider_dict[vector_provider])
```

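Because `vector_provider` is parametrized, every unit test written against the shared `vector_io_adapter` fixture now also runs against the mocked PGVector adapter. A hypothetical example (the test name and assertion are illustrative):

```
# Hypothetical test: pytest runs it once per provider in the `vector_provider`
# params list, now including the mocked "pgvector" adapter.
async def test_adapter_is_wired_up(vector_io_adapter):
    assert vector_io_adapter is not None
```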

```
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from unittest.mock import patch

import pytest

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex

PGVECTOR_PROVIDER = "pgvector"


@pytest.fixture(scope="session")
def loop():
    return asyncio.new_event_loop()


@pytest.fixture
def embedding_dimension():
    """Default embedding dimension for tests."""
    return 384


@pytest.fixture
async def pgvector_index(embedding_dimension, mock_psycopg2_connection):
    """Create a PGVectorIndex instance with mocked database connection."""
    connection, cursor = mock_psycopg2_connection

    vector_db = VectorDB(
        identifier="test-vector-db",
        embedding_model="test-model",
        embedding_dimension=embedding_dimension,
        provider_id=PGVECTOR_PROVIDER,
        provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
    )

    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
        # Use explicit COSINE distance metric for consistent testing
        index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")

    return index, cursor


class TestPGVectorIndex:
    def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection):
        connection, cursor = mock_psycopg2_connection

        vector_db = VectorDB(
            identifier="test-vector-db",
            embedding_model="test-model",
            embedding_dimension=embedding_dimension,
            provider_id=PGVECTOR_PROVIDER,
            provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
        )

        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
            index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2")
            assert index.distance_metric == "L2"

            with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID")

    def test_get_pgvector_search_function(self, pgvector_index):
        index, cursor = pgvector_index
        supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION

        for metric, function in supported_metrics.items():
            index.distance_metric = metric
            assert index.get_pgvector_search_function() == function

    def test_check_distance_metric_availability(self, pgvector_index):
        index, cursor = pgvector_index
        supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION

        for metric in supported_metrics:
            index.check_distance_metric_availability(metric)

        with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
            index.check_distance_metric_availability("INVALID")

    def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection):
        connection, cursor = mock_psycopg2_connection

        vector_db = VectorDB(
            identifier="test-vector-db",
            embedding_model="test-model",
            embedding_dimension=embedding_dimension,
            provider_id=PGVECTOR_PROVIDER,
            provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
        )

        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
            with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"):
                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC")

            with pytest.raises(ValueError, match="Supported metrics are:"):
                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN")

            try:
                index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
                assert index.distance_metric == "COSINE"
            except ValueError:
                pytest.fail("Valid distance metric 'COSINE' should not raise ValueError")

    def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection):
        connection, cursor = mock_psycopg2_connection

        vector_db = VectorDB(
            identifier="test-vector-db",
            embedding_model="test-model",
            embedding_dimension=embedding_dimension,
            provider_id=PGVECTOR_PROVIDER,
            provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
        )

        supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"]

        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
            for metric in supported_metrics:
                try:
                    index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric)
                    assert index.distance_metric == metric

                    expected_operators = {
                        "L2": "<->",
                        "L1": "<+>",
                        "COSINE": "<=>",
                        "INNER_PRODUCT": "<#>",
                        "HAMMING": "<~>",
                        "JACCARD": "<%>",
                    }
                    assert index.get_pgvector_search_function() == expected_operators[metric]
                except Exception as e:
                    pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}")
```
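
The operator mapping checked above corresponds to pgvector's distance operators. As a rough sketch of how such an operator ends up in a similarity query (the table and column names here are illustrative, not the actual SQL built by `PGVectorIndex`):

```
import json


# Same metric-to-operator mapping the tests assert on.
DISTANCE_OPERATORS = {
    "L2": "<->",
    "L1": "<+>",
    "COSINE": "<=>",
    "INNER_PRODUCT": "<#>",
    "HAMMING": "<~>",
    "JACCARD": "<%>",
}


def query_top_k(conn, table: str, embedding: list[float], k: int, metric: str = "COSINE"):
    """Return the k rows closest to `embedding` under the chosen metric (illustrative schema)."""
    operator = DISTANCE_OPERATORS[metric]
    sql = (
        f"SELECT id, document, embedding {operator} %s::vector AS distance "
        f"FROM {table} ORDER BY distance ASC LIMIT %s"
    )
    with conn.cursor() as cur:
        cur.execute(sql, (json.dumps(embedding), k))
        return cur.fetchall()


# Usage, assuming a psycopg2 connection to the llamastack DB from the test plan:
# conn = psycopg2.connect(host="localhost", dbname="llamastack", user="llamastack", password="llamastack")
# rows = query_top_k(conn, "vector_store_pgvector_main", [0.1] * 384, k=5)
```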