Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-10 13:28:40 +00:00)

Merge branch 'main' into main
Commit fc4a75832c
77 changed files with 3806 additions and 2058 deletions
@@ -8,6 +8,7 @@ from io import BytesIO
from unittest.mock import patch

import pytest
import requests

from llama_stack.core.datatypes import User

@@ -79,6 +80,88 @@ def test_openai_client_basic_operations(openai_client):
        pass  # ignore 404


@pytest.mark.xfail(reason="expires_after not available on all providers")
def test_expires_after(openai_client):
    """Test uploading a file with expires_after parameter."""
    client = openai_client

    uploaded_file = None
    try:
        with BytesIO(b"expires_after test") as file_buffer:
            file_buffer.name = "expires_after.txt"
            uploaded_file = client.files.create(
                file=file_buffer,
                purpose="assistants",
                expires_after={"anchor": "created_at", "seconds": 4545},
            )

        assert uploaded_file.expires_at is not None
        assert uploaded_file.expires_at == uploaded_file.created_at + 4545

        listed = client.files.list()
        ids = [f.id for f in listed.data]
        assert uploaded_file.id in ids

        retrieved = client.files.retrieve(uploaded_file.id)
        assert retrieved.id == uploaded_file.id

    finally:
        if uploaded_file is not None:
            try:
                client.files.delete(uploaded_file.id)
            except Exception:
                pass


@pytest.mark.xfail(reason="expires_after not available on all providers")
def test_expires_after_requests(openai_client):
    """Upload a file using requests multipart/form-data and bracketed expires_after fields.

    This ensures clients that send form fields like `expires_after[anchor]` and
    `expires_after[seconds]` are handled by the server.
    """
    base_url = f"{openai_client.base_url}files"

    uploaded_id = None
    try:
        files = {"file": ("expires_after_with_requests.txt", BytesIO(b"expires_after via requests"))}
        data = {
            "purpose": "assistants",
            "expires_after[anchor]": "created_at",
            "expires_after[seconds]": "4545",
        }

        session = requests.Session()
        request = requests.Request("POST", base_url, files=files, data=data)
        prepared = session.prepare_request(request)
        resp = session.send(prepared, timeout=30)
        resp.raise_for_status()
        result = resp.json()

        assert result.get("id", "").startswith("file-")
        uploaded_id = result["id"]
        assert result.get("created_at") is not None
        assert result.get("expires_at") == result["created_at"] + 4545

        list_resp = requests.get(base_url, timeout=30)
        list_resp.raise_for_status()
        listed = list_resp.json()
        ids = [f["id"] for f in listed.get("data", [])]
        assert uploaded_id in ids

        retrieve_resp = requests.get(f"{base_url}/{uploaded_id}", timeout=30)
        retrieve_resp.raise_for_status()
        retrieved = retrieve_resp.json()
        assert retrieved["id"] == uploaded_id

    finally:
        if uploaded_id:
            try:
                requests.delete(f"{base_url}/{uploaded_id}", timeout=30)
            except Exception:
                pass
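# Aside (not part of the diff above): a rough sketch of how a server might fold
# bracketed multipart fields such as `expires_after[anchor]` back into a nested
# mapping before validation. The helper name and behavior are assumptions for
# illustration; the actual llama-stack server-side handling may differ.
def collect_bracketed_fields(form: dict[str, str]) -> dict[str, dict[str, str]]:
    """Group `parent[child]` form keys into {"parent": {"child": value}}."""
    nested: dict[str, dict[str, str]] = {}
    for key, value in form.items():
        if "[" in key and key.endswith("]"):
            parent, child = key[:-1].split("[", 1)
            nested.setdefault(parent, {})[child] = value
    return nested

# collect_bracketed_fields({"expires_after[anchor]": "created_at", "expires_after[seconds]": "4545"})
# -> {"expires_after": {"anchor": "created_at", "seconds": "4545"}}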
@pytest.mark.xfail(reason="User isolation broken for current providers, must be fixed.")
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack_client):

Binary file not shown.
@@ -57,11 +57,13 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
        "inline::sqlite-vec",
        "remote::milvus",
        "inline::milvus",
        "remote::pgvector",
    ],
    "hybrid": [
        "inline::sqlite-vec",
        "inline::milvus",
        "remote::milvus",
        "remote::pgvector",
    ],
}
supported_providers = search_mode_support.get(search_mode, [])
@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import sqlite3
import tempfile
from pathlib import Path
from unittest.mock import patch

@@ -133,7 +132,6 @@ class TestInferenceRecording:
        # Test directory creation
        assert storage.test_dir.exists()
        assert storage.responses_dir.exists()
        assert storage.db_path.exists()

        # Test storing and retrieving a recording
        request_hash = "test_hash_123"

@@ -147,15 +145,6 @@ class TestInferenceRecording:

        storage.store_recording(request_hash, request_data, response_data)

        # Verify SQLite record
        with sqlite3.connect(storage.db_path) as conn:
            result = conn.execute("SELECT * FROM recordings WHERE request_hash = ?", (request_hash,)).fetchone()

        assert result is not None
        assert result[0] == request_hash  # request_hash
        assert result[2] == "/v1/chat/completions"  # endpoint
        assert result[3] == "llama3.2:3b"  # model

        # Verify file storage and retrieval
        retrieved = storage.find_recording(request_hash)
        assert retrieved is not None

@@ -185,10 +174,7 @@ class TestInferenceRecording:

        # Verify recording was stored
        storage = ResponseStorage(temp_storage_dir)
        with sqlite3.connect(storage.db_path) as conn:
            recordings = conn.execute("SELECT COUNT(*) FROM recordings").fetchone()[0]

        assert recordings == 1
        assert storage.responses_dir.exists()

    async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
        """Test that replay mode returns stored responses without making real calls."""
tests/unit/providers/files/conftest.py (new file, 62 lines)

@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import boto3
import pytest
from moto import mock_aws

from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


class MockUploadFile:
    def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
        self.content = content
        self.filename = filename
        self.content_type = content_type

    async def read(self):
        return self.content


@pytest.fixture
def sample_text_file():
    content = b"Hello, this is a test file for the S3 Files API!"
    return MockUploadFile(content, "sample_text_file-0.txt")


@pytest.fixture
def sample_text_file2():
    content = b"Hello, this is a second test file for the S3 Files API!"
    return MockUploadFile(content, "sample_text_file-1.txt")


@pytest.fixture
def s3_config(tmp_path):
    db_path = tmp_path / "s3_files_metadata.db"

    return S3FilesImplConfig(
        bucket_name=f"test-bucket-{tmp_path.name}",
        region="not-a-region",
        auto_create_bucket=True,
        metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
    )


@pytest.fixture
def s3_client():
    # we use `with mock_aws()` because @mock_aws decorator does not support
    # being a generator
    with mock_aws():
        # must yield or the mock will be reset before it is used
        yield boto3.client("s3")


@pytest.fixture
async def s3_provider(s3_config, s3_client):  # s3_client provides the moto mock, don't remove it
    provider = await get_adapter_impl(s3_config, {})
    yield provider
    await provider.shutdown()
@@ -6,63 +6,11 @@

from unittest.mock import patch

import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_aws

from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.providers.remote.files.s3 import (
    S3FilesImplConfig,
    get_adapter_impl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


class MockUploadFile:
    def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
        self.content = content
        self.filename = filename
        self.content_type = content_type

    async def read(self):
        return self.content


@pytest.fixture
def s3_config(tmp_path):
    db_path = tmp_path / "s3_files_metadata.db"

    return S3FilesImplConfig(
        bucket_name="test-bucket",
        region="not-a-region",
        auto_create_bucket=True,
        metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
    )


@pytest.fixture
def s3_client():
    """Create a mocked S3 client for testing."""
    # we use `with mock_aws()` because @mock_aws decorator does not support being a generator
    with mock_aws():
        # must yield or the mock will be reset before it is used
        yield boto3.client("s3")


@pytest.fixture
async def s3_provider(s3_config, s3_client):
    """Create an S3 files provider with mocked S3 for testing."""
    provider = await get_adapter_impl(s3_config, {})
    yield provider
    await provider.shutdown()


@pytest.fixture
def sample_text_file():
    content = b"Hello, this is a test file for the S3 Files API!"
    return MockUploadFile(content, "sample_text_file.txt")


class TestS3FilesImpl:

@@ -143,7 +91,7 @@ class TestS3FilesImpl:
            s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
        assert exc_info.value.response["Error"]["Code"] == "404"

    async def test_list_files(self, s3_provider, sample_text_file):
    async def test_list_files(self, s3_provider, sample_text_file, sample_text_file2):
        """Test listing files after uploading some."""
        sample_text_file.filename = "test_list_files_with_content_file1"
        file1 = await s3_provider.openai_upload_file(

@@ -151,9 +99,9 @@ class TestS3FilesImpl:
            purpose=OpenAIFilePurpose.ASSISTANTS,
        )

        file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
        sample_text_file2.filename = "test_list_files_with_content_file2"
        file2 = await s3_provider.openai_upload_file(
            file=file2_content,
            file=sample_text_file2,
            purpose=OpenAIFilePurpose.BATCH,
        )

@@ -164,7 +112,7 @@ class TestS3FilesImpl:
        assert file1.id in file_ids
        assert file2.id in file_ids

    async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
    async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file, sample_text_file2):
        """Test listing files with purpose filter."""
        sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
        file1 = await s3_provider.openai_upload_file(

@@ -172,9 +120,9 @@ class TestS3FilesImpl:
            purpose=OpenAIFilePurpose.ASSISTANTS,
        )

        file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
        sample_text_file2.filename = "test_list_files_with_purpose_filter_file2"
        await s3_provider.openai_upload_file(
            file=file2_content,
            file=sample_text_file2,
            purpose=OpenAIFilePurpose.BATCH,
        )

@@ -249,3 +197,104 @@ class TestS3FilesImpl:

        files_list = await s3_provider.openai_list_files()
        assert len(files_list.data) == 0, "No file metadata should remain after failed upload"

    @pytest.mark.parametrize("purpose", [p for p in OpenAIFilePurpose if p != OpenAIFilePurpose.BATCH])
    async def test_default_no_expiration(self, s3_provider, sample_text_file, purpose):
        """Test that by default files have no expiration."""
        sample_text_file.filename = "test_default_no_expiration"
        uploaded = await s3_provider.openai_upload_file(
            file=sample_text_file,
            purpose=purpose,
        )
        assert uploaded.expires_at is None, "By default files should have no expiration"

    async def test_default_batch_expiration(self, s3_provider, sample_text_file):
        """Test that by default batch files have an expiration."""
        sample_text_file.filename = "test_default_batch_an_expiration"
        uploaded = await s3_provider.openai_upload_file(
            file=sample_text_file,
            purpose=OpenAIFilePurpose.BATCH,
        )
        assert uploaded.expires_at is not None, "By default batch files should have an expiration"
        thirty_days_seconds = 30 * 24 * 3600
        assert uploaded.expires_at == uploaded.created_at + thirty_days_seconds, (
            "Batch default expiration should be 30 days"
        )

    async def test_expired_file_is_unavailable(self, s3_provider, sample_text_file, s3_config, s3_client):
        """Uploaded file that has expired should not be listed or retrievable/deletable."""
        with patch.object(s3_provider, "_now") as mock_now:  # control time
            two_hours = 2 * 60 * 60

            mock_now.return_value = 0

            sample_text_file.filename = "test_expired_file"
            uploaded = await s3_provider.openai_upload_file(
                file=sample_text_file,
                purpose=OpenAIFilePurpose.ASSISTANTS,
                expires_after_anchor="created_at",
                expires_after_seconds=two_hours,
            )

            mock_now.return_value = two_hours * 2  # fast forward 4 hours

            listed = await s3_provider.openai_list_files()
            assert uploaded.id not in [f.id for f in listed.data]

            with pytest.raises(ResourceNotFoundError, match="not found"):
                await s3_provider.openai_retrieve_file(uploaded.id)

            with pytest.raises(ResourceNotFoundError, match="not found"):
                await s3_provider.openai_retrieve_file_content(uploaded.id)

            with pytest.raises(ResourceNotFoundError, match="not found"):
                await s3_provider.openai_delete_file(uploaded.id)

            with pytest.raises(ClientError) as exc_info:
                s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
            assert exc_info.value.response["Error"]["Code"] == "404"

            with pytest.raises(ResourceNotFoundError, match="not found"):
                await s3_provider._get_file(uploaded.id, return_expired=True)

    async def test_unsupported_expires_after_anchor(self, s3_provider, sample_text_file):
        """Unsupported anchor value should raise ValueError."""
        sample_text_file.filename = "test_unsupported_expires_after_anchor"

        with pytest.raises(ValueError, match="Input should be 'created_at'"):
            await s3_provider.openai_upload_file(
                file=sample_text_file,
                purpose=OpenAIFilePurpose.ASSISTANTS,
                expires_after_anchor="now",
                expires_after_seconds=3600,
            )

    async def test_nonint_expires_after_seconds(self, s3_provider, sample_text_file):
        """Non-integer seconds in expires_after should raise ValueError."""
        sample_text_file.filename = "test_nonint_expires_after_seconds"

        with pytest.raises(ValueError, match="should be a valid integer"):
            await s3_provider.openai_upload_file(
                file=sample_text_file,
                purpose=OpenAIFilePurpose.ASSISTANTS,
                expires_after_anchor="created_at",
                expires_after_seconds="many",
            )

    async def test_expires_after_seconds_out_of_bounds(self, s3_provider, sample_text_file):
        """Seconds outside allowed range should raise ValueError."""
        with pytest.raises(ValueError, match="greater than or equal to 3600"):
            await s3_provider.openai_upload_file(
                file=sample_text_file,
                purpose=OpenAIFilePurpose.ASSISTANTS,
                expires_after_anchor="created_at",
                expires_after_seconds=3599,
            )

        with pytest.raises(ValueError, match="less than or equal to 2592000"):
            await s3_provider.openai_upload_file(
                file=sample_text_file,
                purpose=OpenAIFilePurpose.ASSISTANTS,
                expires_after_anchor="created_at",
                expires_after_seconds=2592001,
            )
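# Aside (not part of the diff above): the error substrings matched by these tests
# ("Input should be 'created_at'", "should be a valid integer",
# "greater than or equal to 3600", "less than or equal to 2592000") are the kind
# of messages a pydantic model like the hypothetical one below would produce.
# The provider's real validation model may be defined differently.
from typing import Literal

from pydantic import BaseModel, Field


class ExpiresAfterSketch(BaseModel):
    anchor: Literal["created_at"]  # only the created_at anchor is accepted
    seconds: int = Field(ge=3600, le=2592000)  # between 1 hour and 30 days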
tests/unit/providers/files/test_s3_files_auth.py (new file, 89 lines)

@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import patch

import pytest

from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.core.datatypes import User
from llama_stack.providers.remote.files.s3.files import S3FilesImpl


async def test_listing_hides_other_users_file(s3_provider, sample_text_file):
    """Listing should not show files uploaded by other users."""
    user_a = User("user-a", {"roles": ["team-a"]})
    user_b = User("user-b", {"roles": ["team-b"]})

    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
        mock_get_user.return_value = user_a
        uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)

    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
        mock_get_user.return_value = user_b
        listed = await s3_provider.openai_list_files()
        assert all(f.id != uploaded.id for f in listed.data)


@pytest.mark.parametrize(
    "op",
    [S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file],
    ids=["retrieve", "content", "delete"],
)
async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op):
    """Operations (metadata/content/delete) on another user's file should raise ResourceNotFoundError.

    `op` is an async callable (provider, file_id) -> awaits the requested operation.
    """
    user_a = User("user-a", {"roles": ["team-a"]})
    user_b = User("user-b", {"roles": ["team-b"]})

    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
        mock_get_user.return_value = user_a
        uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)

    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
        mock_get_user.return_value = user_b
        with pytest.raises(ResourceNotFoundError):
            await op(s3_provider, uploaded.id)


async def test_shared_role_allows_listing(s3_provider, sample_text_file):
    """Listing should show files uploaded by other users when roles are shared."""
    user_a = User("user-a", {"roles": ["shared-role"]})
    user_b = User("user-b", {"roles": ["shared-role"]})

    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
        mock_get_user.return_value = user_a
        uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)

    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
        mock_get_user.return_value = user_b
        listed = await s3_provider.openai_list_files()
        assert any(f.id == uploaded.id for f in listed.data)


@pytest.mark.parametrize(
    "op",
    [S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file],
    ids=["retrieve", "content", "delete"],
)
async def test_shared_role_allows_access(s3_provider, sample_text_file, op):
    """Operations (metadata/content/delete) on another user's file should succeed when users share a role.

    `op` is an async callable (provider, file_id) -> awaits the requested operation.
    """
    user_x = User("user-x", {"roles": ["shared-role"]})
    user_y = User("user-y", {"roles": ["shared-role"]})

    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
        mock_get_user.return_value = user_x
        uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)

    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
        mock_get_user.return_value = user_y
        await op(s3_provider, uploaded.id)
tests/unit/providers/utils/memory/test_reranking.py (new file, 248 lines)

@@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from llama_stack.providers.utils.memory.vector_store import RERANKER_TYPE_RRF, RERANKER_TYPE_WEIGHTED
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator


class TestNormalizeScores:
    """Test cases for score normalization."""

    def test_normalize_scores_basic(self):
        """Test basic score normalization."""
        scores = {"doc1": 10.0, "doc2": 5.0, "doc3": 0.0}
        normalized = WeightedInMemoryAggregator._normalize_scores(scores)

        assert normalized["doc1"] == 1.0  # Max score
        assert normalized["doc3"] == 0.0  # Min score
        assert normalized["doc2"] == 0.5  # Middle score
        assert all(0 <= score <= 1 for score in normalized.values())

    def test_normalize_scores_identical(self):
        """Test normalization when all scores are identical."""
        scores = {"doc1": 5.0, "doc2": 5.0, "doc3": 5.0}
        normalized = WeightedInMemoryAggregator._normalize_scores(scores)

        # All scores should be 1.0 when identical
        assert all(score == 1.0 for score in normalized.values())

    def test_normalize_scores_empty(self):
        """Test normalization with empty scores."""
        scores = {}
        normalized = WeightedInMemoryAggregator._normalize_scores(scores)

        assert normalized == {}

    def test_normalize_scores_single(self):
        """Test normalization with single score."""
        scores = {"doc1": 7.5}
        normalized = WeightedInMemoryAggregator._normalize_scores(scores)

        assert normalized["doc1"] == 1.0


class TestWeightedRerank:
    """Test cases for weighted reranking."""

    def test_weighted_rerank_basic(self):
        """Test basic weighted reranking."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}
        keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9}

        combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.5)

        # Should include all documents
        expected_docs = {"doc1", "doc2", "doc3", "doc4"}
        assert set(combined.keys()) == expected_docs

        # All scores should be between 0 and 1
        assert all(0 <= score <= 1 for score in combined.values())

        # doc1 appears in both searches, should have higher combined score
        assert combined["doc1"] > 0

    def test_weighted_rerank_alpha_zero(self):
        """Test weighted reranking with alpha=0 (keyword only)."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}  # All docs present in vector
        keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9}  # All docs present in keyword

        combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.0)

        # Alpha=0 means vector scores are ignored, keyword scores dominate
        # doc3 should score highest since it has highest keyword score
        assert combined["doc3"] > combined["doc2"] > combined["doc1"]

    def test_weighted_rerank_alpha_one(self):
        """Test weighted reranking with alpha=1 (vector only)."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}  # All docs present in vector
        keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9}  # All docs present in keyword

        combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=1.0)

        # Alpha=1 means keyword scores are ignored, vector scores dominate
        # doc1 should score highest since it has highest vector score
        assert combined["doc1"] > combined["doc2"] > combined["doc3"]

    def test_weighted_rerank_no_overlap(self):
        """Test weighted reranking with no overlapping documents."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc3": 0.8, "doc4": 0.6}

        combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.5)

        assert len(combined) == 4
        # With min-max normalization, lowest scoring docs in each group get 0.0
        # but highest scoring docs should get positive scores
        assert all(score >= 0 for score in combined.values())
        assert combined["doc1"] > 0  # highest vector score
        assert combined["doc3"] > 0  # highest keyword score
class TestRRFRerank:
    """Test cases for RRF (Reciprocal Rank Fusion) reranking."""

    def test_rrf_rerank_basic(self):
        """Test basic RRF reranking."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5}
        keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9}

        combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)

        # Should include all documents
        expected_docs = {"doc1", "doc2", "doc3", "doc4"}
        assert set(combined.keys()) == expected_docs

        # All scores should be positive
        assert all(score > 0 for score in combined.values())

        # Documents appearing in both searches should have higher scores
        # doc1 and doc2 appear in both, doc3 and doc4 appear in only one
        assert combined["doc1"] > combined["doc3"]
        assert combined["doc2"] > combined["doc4"]

    def test_rrf_rerank_rank_calculation(self):
        """Test that RRF correctly calculates ranks."""
        # Create clear ranking order
        vector_scores = {"doc1": 1.0, "doc2": 0.8, "doc3": 0.6}  # Ranks: 1, 2, 3
        keyword_scores = {"doc1": 0.5, "doc2": 1.0, "doc3": 0.7}  # Ranks: 3, 1, 2

        combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)

        # doc1: rank 1 in vector, rank 3 in keyword
        # doc2: rank 2 in vector, rank 1 in keyword
        # doc3: rank 3 in vector, rank 2 in keyword

        # doc2 should have the highest combined score (ranks 2+1=3)
        # followed by doc1 (ranks 1+3=4) and doc3 (ranks 3+2=5)
        # Remember: lower rank sum = higher RRF score
        assert combined["doc2"] > combined["doc1"] > combined["doc3"]

    def test_rrf_rerank_impact_factor(self):
        """Test that impact factor affects RRF scores."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc1": 0.8, "doc2": 0.6}

        combined_low = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=10.0)
        combined_high = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=100.0)

        # Higher impact factor should generally result in lower scores
        # (because 1/(k+r) decreases as k increases)
        assert combined_low["doc1"] > combined_high["doc1"]
        assert combined_low["doc2"] > combined_high["doc2"]

    def test_rrf_rerank_missing_documents(self):
        """Test RRF handling of documents missing from one search."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc1": 0.8, "doc3": 0.6}

        combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0)

        # Should include all documents
        assert len(combined) == 3

        # doc1 appears in both searches, should have highest score
        assert combined["doc1"] > combined["doc2"]
        assert combined["doc1"] > combined["doc3"]
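# Aside (not part of the test file): a minimal sketch of reciprocal rank fusion as
# these tests read it. Each document contributes 1 / (impact_factor + rank) from
# every result list it appears in, so lower rank sums give higher fused scores and
# a larger impact_factor flattens all scores. Illustrative only.
def rrf_rerank_sketch(vector_scores, keyword_scores, impact_factor=60.0):
    def ranks(scores):
        ordered = sorted(scores, key=scores.get, reverse=True)
        return {doc: position + 1 for position, doc in enumerate(ordered)}

    vec_ranks, kw_ranks = ranks(vector_scores), ranks(keyword_scores)
    fused = {}
    for doc in set(vec_ranks) | set(kw_ranks):
        score = 0.0
        if doc in vec_ranks:
            score += 1.0 / (impact_factor + vec_ranks[doc])
        if doc in kw_ranks:
            score += 1.0 / (impact_factor + kw_ranks[doc])
        fused[doc] = score
    return fused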
class TestCombineSearchResults:
    """Test cases for the main combine_search_results function."""

    def test_combine_search_results_rrf_default(self):
        """Test combining with RRF as default."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc1": 0.6, "doc3": 0.8}

        combined = WeightedInMemoryAggregator.combine_search_results(vector_scores, keyword_scores)

        # Should default to RRF
        assert len(combined) == 3
        assert all(score > 0 for score in combined.values())

    def test_combine_search_results_rrf_explicit(self):
        """Test combining with explicit RRF."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc1": 0.6, "doc3": 0.8}

        combined = WeightedInMemoryAggregator.combine_search_results(
            vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_RRF, reranker_params={"impact_factor": 50.0}
        )

        assert len(combined) == 3
        assert all(score > 0 for score in combined.values())

    def test_combine_search_results_weighted(self):
        """Test combining with weighted reranking."""
        vector_scores = {"doc1": 0.9, "doc2": 0.7}
        keyword_scores = {"doc1": 0.6, "doc3": 0.8}

        combined = WeightedInMemoryAggregator.combine_search_results(
            vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_WEIGHTED, reranker_params={"alpha": 0.3}
        )

        assert len(combined) == 3
        assert all(0 <= score <= 1 for score in combined.values())

    def test_combine_search_results_unknown_type(self):
        """Test combining with unknown reranker type defaults to RRF."""
        vector_scores = {"doc1": 0.9}
        keyword_scores = {"doc2": 0.8}

        combined = WeightedInMemoryAggregator.combine_search_results(
            vector_scores, keyword_scores, reranker_type="unknown_type"
        )

        # Should fall back to RRF
        assert len(combined) == 2
        assert all(score > 0 for score in combined.values())

    def test_combine_search_results_empty_params(self):
        """Test combining with empty parameters."""
        vector_scores = {"doc1": 0.9}
        keyword_scores = {"doc2": 0.8}

        combined = WeightedInMemoryAggregator.combine_search_results(vector_scores, keyword_scores, reranker_params={})

        # Should use default parameters
        assert len(combined) == 2
        assert all(score > 0 for score in combined.values())

    def test_combine_search_results_empty_scores(self):
        """Test combining with empty score dictionaries."""
        # Test with empty vector scores
        combined = WeightedInMemoryAggregator.combine_search_results({}, {"doc1": 0.8})
        assert len(combined) == 1
        assert combined["doc1"] > 0

        # Test with empty keyword scores
        combined = WeightedInMemoryAggregator.combine_search_results({"doc1": 0.9}, {})
        assert len(combined) == 1
        assert combined["doc1"] > 0

        # Test with both empty
        combined = WeightedInMemoryAggregator.combine_search_results({}, {})
        assert len(combined) == 0
@@ -5,6 +5,7 @@
# the root directory of this source tree.

import random
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest

@@ -12,7 +13,7 @@ from chromadb import PersistentClient
from pymilvus import MilvusClient, connections

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter

@@ -22,6 +23,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter

EMBEDDING_DIMENSION = 384

@@ -29,7 +32,7 @@ COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"


@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
def vector_provider(request):
    return request.param

@@ -333,15 +336,127 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
    await index.delete()


@pytest.fixture
def mock_psycopg2_connection():
    connection = MagicMock()
    cursor = MagicMock()

    cursor.__enter__ = MagicMock(return_value=cursor)
    cursor.__exit__ = MagicMock()

    connection.cursor.return_value = cursor

    return connection, cursor


@pytest.fixture
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
    connection, cursor = mock_psycopg2_connection

    vector_db = VectorDB(
        identifier="test-vector-db",
        embedding_model="test-model",
        embedding_dimension=embedding_dimension,
        provider_id="pgvector",
        provider_resource_id="pgvector:test-vector-db",
    )

    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
            index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
            index._test_chunks = []
            original_add_chunks = index.add_chunks

            async def mock_add_chunks(chunks, embeddings):
                index._test_chunks = list(chunks)
                await original_add_chunks(chunks, embeddings)

            index.add_chunks = mock_add_chunks

            async def mock_query_vector(embedding, k, score_threshold):
                chunks = index._test_chunks[:k] if hasattr(index, "_test_chunks") else []
                scores = [1.0] * len(chunks)
                return QueryChunksResponse(chunks=chunks, scores=scores)

            index.query_vector = mock_query_vector

            yield index


@pytest.fixture
async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
    config = PGVectorVectorIOConfig(
        host="localhost",
        port=5432,
        db="test_db",
        user="test_user",
        password="test_password",
        kvstore=SqliteKVStoreConfig(),
    )

    adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)

    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2.connect") as mock_connect:
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
        mock_cursor.__exit__ = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_conn.autocommit = True
        mock_connect.return_value = mock_conn

        with patch(
            "llama_stack.providers.remote.vector_io.pgvector.pgvector.check_extension_version"
        ) as mock_check_version:
            mock_check_version.return_value = "0.5.1"

            with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl:
                mock_kvstore = AsyncMock()
                mock_kvstore_impl.return_value = mock_kvstore

                with patch.object(adapter, "initialize_openai_vector_stores", new_callable=AsyncMock):
                    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.upsert_models"):
                        await adapter.initialize()
                        adapter.conn = mock_conn

                        async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
                            index = await adapter._get_and_cache_vector_db_index(vector_db_id)
                            if not index:
                                raise ValueError(f"Vector DB {vector_db_id} not found")
                            await index.insert_chunks(chunks)

                        adapter.insert_chunks = mock_insert_chunks

                        async def mock_query_chunks(vector_db_id, query, params=None):
                            index = await adapter._get_and_cache_vector_db_index(vector_db_id)
                            if not index:
                                raise ValueError(f"Vector DB {vector_db_id} not found")
                            return await index.query_chunks(query, params)

                        adapter.query_chunks = mock_query_chunks

                        test_vector_db = VectorDB(
                            identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
                            provider_id="test_provider",
                            embedding_model="test_model",
                            embedding_dimension=embedding_dimension,
                        )
                        await adapter.register_vector_db(test_vector_db)
                        adapter.test_collection_id = test_vector_db.identifier

                        yield adapter
                        await adapter.shutdown()


@pytest.fixture
def vector_io_adapter(vector_provider, request):
    """Returns the appropriate vector IO adapter based on the provider parameter."""
    vector_provider_dict = {
        "milvus": "milvus_vec_adapter",
        "faiss": "faiss_vec_adapter",
        "sqlite_vec": "sqlite_vec_adapter",
        "chroma": "chroma_vec_adapter",
        "qdrant": "qdrant_vec_adapter",
        "pgvector": "pgvector_vec_adapter",
    }
    return request.getfixturevalue(vector_provider_dict[vector_provider])
tests/unit/providers/vector_io/remote/test_pgvector.py (new file, 138 lines)

@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from unittest.mock import patch

import pytest

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex

PGVECTOR_PROVIDER = "pgvector"


@pytest.fixture(scope="session")
def loop():
    return asyncio.new_event_loop()


@pytest.fixture
def embedding_dimension():
    """Default embedding dimension for tests."""
    return 384


@pytest.fixture
async def pgvector_index(embedding_dimension, mock_psycopg2_connection):
    """Create a PGVectorIndex instance with mocked database connection."""
    connection, cursor = mock_psycopg2_connection

    vector_db = VectorDB(
        identifier="test-vector-db",
        embedding_model="test-model",
        embedding_dimension=embedding_dimension,
        provider_id=PGVECTOR_PROVIDER,
        provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
    )

    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
        # Use explicit COSINE distance metric for consistent testing
        index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")

    return index, cursor


class TestPGVectorIndex:
    def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection):
        connection, cursor = mock_psycopg2_connection

        vector_db = VectorDB(
            identifier="test-vector-db",
            embedding_model="test-model",
            embedding_dimension=embedding_dimension,
            provider_id=PGVECTOR_PROVIDER,
            provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
        )

        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
            index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2")
            assert index.distance_metric == "L2"
            with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID")

    def test_get_pgvector_search_function(self, pgvector_index):
        index, cursor = pgvector_index
        supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION

        for metric, function in supported_metrics.items():
            index.distance_metric = metric
            assert index.get_pgvector_search_function() == function

    def test_check_distance_metric_availability(self, pgvector_index):
        index, cursor = pgvector_index
        supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION

        for metric in supported_metrics:
            index.check_distance_metric_availability(metric)

        with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
            index.check_distance_metric_availability("INVALID")

    def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection):
        connection, cursor = mock_psycopg2_connection

        vector_db = VectorDB(
            identifier="test-vector-db",
            embedding_model="test-model",
            embedding_dimension=embedding_dimension,
            provider_id=PGVECTOR_PROVIDER,
            provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
        )

        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
            with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"):
                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC")

            with pytest.raises(ValueError, match="Supported metrics are:"):
                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN")

            try:
                index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
                assert index.distance_metric == "COSINE"
            except ValueError:
                pytest.fail("Valid distance metric 'COSINE' should not raise ValueError")

    def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection):
        connection, cursor = mock_psycopg2_connection

        vector_db = VectorDB(
            identifier="test-vector-db",
            embedding_model="test-model",
            embedding_dimension=embedding_dimension,
            provider_id=PGVECTOR_PROVIDER,
            provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
        )

        supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"]

        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
            for metric in supported_metrics:
                try:
                    index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric)
                    assert index.distance_metric == metric

                    expected_operators = {
                        "L2": "<->",
                        "L1": "<+>",
                        "COSINE": "<=>",
                        "INNER_PRODUCT": "<#>",
                        "HAMMING": "<~>",
                        "JACCARD": "<%>",
                    }
                    assert index.get_pgvector_search_function() == expected_operators[metric]
                except Exception as e:
                    pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}")
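# Aside (not part of the diff above): the mapped symbols are pgvector's distance
# operators, used in SQL to order a nearest-neighbour scan. A hypothetical query
# for illustration only; table and column names are made up, and the provider's
# real query text lives in PGVectorIndex.
cosine_operator = "<=>"  # COSINE, per the mapping above
example_query = (
    f"SELECT chunk_id, embedding {cosine_operator} %s AS distance "
    "FROM example_chunks ORDER BY distance ASC LIMIT %s"
)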
@@ -88,3 +88,10 @@ def test_nested_structures(setup_env_vars):
    }
    expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
    assert replace_env_vars(data) == expected


def test_explicit_strings_preserved(setup_env_vars):
    # Explicit strings that look like numbers/booleans should remain strings
    data = {"port": "8080", "enabled": "true", "count": "123", "ratio": "3.14"}
    expected = {"port": "8080", "enabled": "true", "count": "123", "ratio": "3.14"}
    assert replace_env_vars(data) == expected
@@ -332,6 +332,63 @@ async def test_sqlstore_pagination_error_handling():
    )


async def test_where_operator_gt_and_update_delete():
    with TemporaryDirectory() as tmp_dir:
        db_path = tmp_dir + "/test.db"
        store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))

        await store.create_table(
            "items",
            {
                "id": ColumnType.INTEGER,
                "value": ColumnType.INTEGER,
                "name": ColumnType.STRING,
            },
        )

        await store.insert("items", {"id": 1, "value": 10, "name": "one"})
        await store.insert("items", {"id": 2, "value": 20, "name": "two"})
        await store.insert("items", {"id": 3, "value": 30, "name": "three"})

        result = await store.fetch_all("items", where={"value": {">": 15}})
        assert {r["id"] for r in result.data} == {2, 3}

        row = await store.fetch_one("items", where={"value": {">=": 30}})
        assert row["id"] == 3

        await store.update("items", {"name": "small"}, {"value": {"<": 25}})
        rows = (await store.fetch_all("items")).data
        names = {r["id"]: r["name"] for r in rows}
        assert names[1] == "small"
        assert names[2] == "small"
        assert names[3] == "three"

        await store.delete("items", {"id": {"==": 2}})
        rows_after = (await store.fetch_all("items")).data
        assert {r["id"] for r in rows_after} == {1, 3}


async def test_where_operator_edge_cases():
    with TemporaryDirectory() as tmp_dir:
        db_path = tmp_dir + "/test.db"
        store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))

        await store.create_table(
            "events",
            {"id": ColumnType.STRING, "ts": ColumnType.INTEGER},
        )

        base = 1024
        await store.insert("events", {"id": "a", "ts": base - 10})
        await store.insert("events", {"id": "b", "ts": base + 10})

        row = await store.fetch_one("events", where={"id": "a"})
        assert row["id"] == "a"

        with pytest.raises(ValueError, match="Unsupported operator"):
            await store.fetch_all("events", where={"ts": {"!=": base}})


async def test_sqlstore_pagination_custom_key_column():
    """Test pagination with custom primary key column (not 'id')."""
    with TemporaryDirectory() as tmp_dir: