From e96e3c4da430903dbf4e410f05909b99f47f358c Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 29 Aug 2025 10:14:00 -0400 Subject: [PATCH 1/3] feat(s3 auth): add authorization support for s3 files provider (#3265) # What does this PR do? adds support for authorized users to the s3 files provider ## Test Plan existing and new unit tests --- .../providers/remote/files/s3/__init__.py | 7 +- .../providers/remote/files/s3/files.py | 20 +++-- tests/unit/providers/files/conftest.py | 62 +++++++++++++ tests/unit/providers/files/test_s3_files.py | 64 ++----------- .../providers/files/test_s3_files_auth.py | 89 +++++++++++++++++++ 5 files changed, 172 insertions(+), 70 deletions(-) create mode 100644 tests/unit/providers/files/conftest.py create mode 100644 tests/unit/providers/files/test_s3_files_auth.py diff --git a/llama_stack/providers/remote/files/s3/__init__.py b/llama_stack/providers/remote/files/s3/__init__.py index 3f5dfc88a..7027f1db3 100644 --- a/llama_stack/providers/remote/files/s3/__init__.py +++ b/llama_stack/providers/remote/files/s3/__init__.py @@ -6,15 +6,14 @@ from typing import Any -from llama_stack.core.datatypes import Api +from llama_stack.core.datatypes import AccessRule, Api from .config import S3FilesImplConfig -async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]): +async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None): from .files import S3FilesImpl - # TODO: authorization policies and user separation - impl = S3FilesImpl(config) + impl = S3FilesImpl(config, policy or []) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index 52e0cbbf4..0451f74ea 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -21,8 +21,10 @@ from llama_stack.apis.files import ( OpenAIFileObject, OpenAIFilePurpose, ) +from llama_stack.core.datatypes import AccessRule from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType -from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl +from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from .config import S3FilesImplConfig @@ -89,16 +91,17 @@ class S3FilesImpl(Files): # TODO: implement expiration, for now a silly offset _SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60 - def __init__(self, config: S3FilesImplConfig) -> None: + def __init__(self, config: S3FilesImplConfig, policy: list[AccessRule]) -> None: self._config = config + self.policy = policy self._client: boto3.client | None = None - self._sql_store: SqlStore | None = None + self._sql_store: AuthorizedSqlStore | None = None async def initialize(self) -> None: self._client = _create_s3_client(self._config) await _create_bucket_if_not_exists(self._client, self._config) - self._sql_store = sqlstore_impl(self._config.metadata_store) + self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store)) await self._sql_store.create_table( "openai_files", { @@ -121,7 +124,7 @@ class S3FilesImpl(Files): return self._client @property - def sql_store(self) -> SqlStore: + def sql_store(self) -> AuthorizedSqlStore: assert self._sql_store is not None, "Provider not initialized" return self._sql_store @@ -189,6 +192,7 @@ class S3FilesImpl(Files): paginated_result = await 
self.sql_store.fetch_all( table="openai_files", + policy=self.policy, where=where_conditions if where_conditions else None, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, @@ -216,7 +220,7 @@ class S3FilesImpl(Files): ) async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: - row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) if not row: raise ResourceNotFoundError(file_id, "File", "files.list()") @@ -230,7 +234,7 @@ class S3FilesImpl(Files): ) async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse: - row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) if not row: raise ResourceNotFoundError(file_id, "File", "files.list()") @@ -248,7 +252,7 @@ class S3FilesImpl(Files): return OpenAIFileDeleteResponse(id=file_id, deleted=True) async def openai_retrieve_file_content(self, file_id: str) -> Response: - row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) if not row: raise ResourceNotFoundError(file_id, "File", "files.list()") diff --git a/tests/unit/providers/files/conftest.py b/tests/unit/providers/files/conftest.py new file mode 100644 index 000000000..46282e3dc --- /dev/null +++ b/tests/unit/providers/files/conftest.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import boto3 +import pytest +from moto import mock_aws + +from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig + + +class MockUploadFile: + def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"): + self.content = content + self.filename = filename + self.content_type = content_type + + async def read(self): + return self.content + + +@pytest.fixture +def sample_text_file(): + content = b"Hello, this is a test file for the S3 Files API!" + return MockUploadFile(content, "sample_text_file-0.txt") + + +@pytest.fixture +def sample_text_file2(): + content = b"Hello, this is a second test file for the S3 Files API!" 
+ return MockUploadFile(content, "sample_text_file-1.txt") + + +@pytest.fixture +def s3_config(tmp_path): + db_path = tmp_path / "s3_files_metadata.db" + + return S3FilesImplConfig( + bucket_name=f"test-bucket-{tmp_path.name}", + region="not-a-region", + auto_create_bucket=True, + metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()), + ) + + +@pytest.fixture +def s3_client(): + # we use `with mock_aws()` because @mock_aws decorator does not support + # being a generator + with mock_aws(): + # must yield or the mock will be reset before it is used + yield boto3.client("s3") + + +@pytest.fixture +async def s3_provider(s3_config, s3_client): # s3_client provides the moto mock, don't remove it + provider = await get_adapter_impl(s3_config, {}) + yield provider + await provider.shutdown() diff --git a/tests/unit/providers/files/test_s3_files.py b/tests/unit/providers/files/test_s3_files.py index daa250f10..3bd4836df 100644 --- a/tests/unit/providers/files/test_s3_files.py +++ b/tests/unit/providers/files/test_s3_files.py @@ -6,63 +6,11 @@ from unittest.mock import patch -import boto3 import pytest from botocore.exceptions import ClientError -from moto import mock_aws from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.files import OpenAIFilePurpose -from llama_stack.providers.remote.files.s3 import ( - S3FilesImplConfig, - get_adapter_impl, -) -from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig - - -class MockUploadFile: - def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"): - self.content = content - self.filename = filename - self.content_type = content_type - - async def read(self): - return self.content - - -@pytest.fixture -def s3_config(tmp_path): - db_path = tmp_path / "s3_files_metadata.db" - - return S3FilesImplConfig( - bucket_name="test-bucket", - region="not-a-region", - auto_create_bucket=True, - metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()), - ) - - -@pytest.fixture -def s3_client(): - """Create a mocked S3 client for testing.""" - # we use `with mock_aws()` because @mock_aws decorator does not support being a generator - with mock_aws(): - # must yield or the mock will be reset before it is used - yield boto3.client("s3") - - -@pytest.fixture -async def s3_provider(s3_config, s3_client): - """Create an S3 files provider with mocked S3 for testing.""" - provider = await get_adapter_impl(s3_config, {}) - yield provider - await provider.shutdown() - - -@pytest.fixture -def sample_text_file(): - content = b"Hello, this is a test file for the S3 Files API!" 
- return MockUploadFile(content, "sample_text_file.txt") class TestS3FilesImpl: @@ -143,7 +91,7 @@ class TestS3FilesImpl: s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id) assert exc_info.value.response["Error"]["Code"] == "404" - async def test_list_files(self, s3_provider, sample_text_file): + async def test_list_files(self, s3_provider, sample_text_file, sample_text_file2): """Test listing files after uploading some.""" sample_text_file.filename = "test_list_files_with_content_file1" file1 = await s3_provider.openai_upload_file( @@ -151,9 +99,9 @@ class TestS3FilesImpl: purpose=OpenAIFilePurpose.ASSISTANTS, ) - file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2") + sample_text_file2.filename = "test_list_files_with_content_file2" file2 = await s3_provider.openai_upload_file( - file=file2_content, + file=sample_text_file2, purpose=OpenAIFilePurpose.BATCH, ) @@ -164,7 +112,7 @@ class TestS3FilesImpl: assert file1.id in file_ids assert file2.id in file_ids - async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file): + async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file, sample_text_file2): """Test listing files with purpose filter.""" sample_text_file.filename = "test_list_files_with_purpose_filter_file1" file1 = await s3_provider.openai_upload_file( @@ -172,9 +120,9 @@ class TestS3FilesImpl: purpose=OpenAIFilePurpose.ASSISTANTS, ) - file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2") + sample_text_file2.filename = "test_list_files_with_purpose_filter_file2" await s3_provider.openai_upload_file( - file=file2_content, + file=sample_text_file2, purpose=OpenAIFilePurpose.BATCH, ) diff --git a/tests/unit/providers/files/test_s3_files_auth.py b/tests/unit/providers/files/test_s3_files_auth.py new file mode 100644 index 000000000..6097f2808 --- /dev/null +++ b/tests/unit/providers/files/test_s3_files_auth.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
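+# These tests exercise per-user authorization in the S3 files provider by
+# patching get_authenticated_user in the authorized sqlstore module to simulate
+# requests from different users. A file uploaded by one user must not be
+# visible to a user whose roles do not overlap, while users that share a role
+# can list and operate on each other's files.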
+ +from unittest.mock import patch + +import pytest + +from llama_stack.apis.common.errors import ResourceNotFoundError +from llama_stack.apis.files import OpenAIFilePurpose +from llama_stack.core.datatypes import User +from llama_stack.providers.remote.files.s3.files import S3FilesImpl + + +async def test_listing_hides_other_users_file(s3_provider, sample_text_file): + """Listing should not show files uploaded by other users.""" + user_a = User("user-a", {"roles": ["team-a"]}) + user_b = User("user-b", {"roles": ["team-b"]}) + + with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + mock_get_user.return_value = user_a + uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) + + with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + mock_get_user.return_value = user_b + listed = await s3_provider.openai_list_files() + assert all(f.id != uploaded.id for f in listed.data) + + +@pytest.mark.parametrize( + "op", + [S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file], + ids=["retrieve", "content", "delete"], +) +async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op): + """Operations (metadata/content/delete) on another user's file should raise ResourceNotFoundError. + + `op` is an async callable (provider, file_id) -> awaits the requested operation. + """ + user_a = User("user-a", {"roles": ["team-a"]}) + user_b = User("user-b", {"roles": ["team-b"]}) + + with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + mock_get_user.return_value = user_a + uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) + + with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + mock_get_user.return_value = user_b + with pytest.raises(ResourceNotFoundError): + await op(s3_provider, uploaded.id) + + +async def test_shared_role_allows_listing(s3_provider, sample_text_file): + """Listing should show files uploaded by other users when roles are shared.""" + user_a = User("user-a", {"roles": ["shared-role"]}) + user_b = User("user-b", {"roles": ["shared-role"]}) + + with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + mock_get_user.return_value = user_a + uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) + + with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + mock_get_user.return_value = user_b + listed = await s3_provider.openai_list_files() + assert any(f.id == uploaded.id for f in listed.data) + + +@pytest.mark.parametrize( + "op", + [S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file], + ids=["retrieve", "content", "delete"], +) +async def test_shared_role_allows_access(s3_provider, sample_text_file, op): + """Operations (metadata/content/delete) on another user's file should succeed when users share a role. + + `op` is an async callable (provider, file_id) -> awaits the requested operation. 
+ """ + user_x = User("user-x", {"roles": ["shared-role"]}) + user_y = User("user-y", {"roles": ["shared-role"]}) + + with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + mock_get_user.return_value = user_x + uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) + + with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + mock_get_user.return_value = user_y + await op(s3_provider, uploaded.id) From 3130ca0a787bf9d6bf936229fcf4c334d58b8a70 Mon Sep 17 00:00:00 2001 From: IAN MILLER <75687988+r3v5@users.noreply.github.com> Date: Fri, 29 Aug 2025 15:30:12 +0100 Subject: [PATCH 2/3] feat: implement keyword, vector and hybrid search inside vector stores for PGVector provider (#3064) # What does this PR do? The purpose of this task is to implement `openai/v1/vector_stores/{vector_store_id}/search` for PGVector provider. It involves implementing vector similarity search, keyword search and hybrid search for `PGVectorIndex`. Closes #3006 ## Test Plan Run unit tests: ` ./scripts/unit-tests.sh ` Run integration tests for openai vector stores: 1. Export env vars: ``` export ENABLE_PGVECTOR=true export PGVECTOR_HOST=localhost export PGVECTOR_PORT=5432 export PGVECTOR_DB=llamastack export PGVECTOR_USER=llamastack export PGVECTOR_PASSWORD=llamastack ``` 2. Create DB: ``` psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';" psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;" psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;" ``` 3. Install sentence-transformers: ` uv pip install sentence-transformers ` 4. Run: ``` uv run --group test pytest -s -v --stack-config="inference=inline::sentence-transformers,vector_io=remote::pgvector" --embedding-model sentence-transformers/all-MiniLM-L6-v2 tests/integration/vector_io/test_openai_vector_stores.py ``` Inspect PGVector vector stores (optional): ``` psql llamastack psql (14.18 (Homebrew)) Type "help" for help. 
llamastack=# \z Access privileges Schema | Name | Type | Access privileges | Column privileges | Policies --------+------------------------------------------------------+-------+-------------------+-------------------+---------- public | llamastack_kvstore | table | | | public | metadata_store | table | | | public | vector_store_pgvector_main | table | | | public | vector_store_vs_1dfbc061_1f4d_4497_9165_ecba2622ba3a | table | | | public | vector_store_vs_2085a9fb_1822_4e42_a277_c6a685843fa7 | table | | | public | vector_store_vs_2b3dae46_38be_462a_afd6_37ee5fe661b1 | table | | | public | vector_store_vs_2f438de6_f606_4561_9d50_ef9160eb9060 | table | | | public | vector_store_vs_3eeca564_2580_4c68_bfea_83dc57e31214 | table | | | public | vector_store_vs_53942163_05f3_40e0_83c0_0997c64613da | table | | | public | vector_store_vs_545bac75_8950_4ff1_b084_e221192d4709 | table | | | public | vector_store_vs_688a37d8_35b2_4298_a035_bfedf5b21f86 | table | | | public | vector_store_vs_70624d9a_f6ac_4c42_b8ab_0649473c6600 | table | | | public | vector_store_vs_73fc1dd2_e942_4972_afb1_1e177b591ac2 | table | | | public | vector_store_vs_9d464949_d51f_49db_9f87_e033b8b84ac9 | table | | | public | vector_store_vs_a1e4d724_5162_4d6d_a6c0_bdafaf6b76ec | table | | | public | vector_store_vs_a328fb1b_1a21_480f_9624_ffaa60fb6672 | table | | | public | vector_store_vs_a8981bf0_2e66_4445_a267_a8fff442db53 | table | | | public | vector_store_vs_ccd4b6a4_1efd_4984_ad03_e7ff8eadb296 | table | | | public | vector_store_vs_cd6420a4_a1fc_4cec_948c_1413a26281c9 | table | | | public | vector_store_vs_cd709284_e5cf_4a88_aba5_dc76a35364bd | table | | | public | vector_store_vs_d7a4548e_fbc1_44d7_b2ec_b664417f2a46 | table | | | public | vector_store_vs_e7f73231_414c_4523_886c_d1174eee836e | table | | | public | vector_store_vs_ffd53588_819f_47e8_bb9d_954af6f7833d | table | | | (23 rows) llamastack=# ``` Co-authored-by: Francisco Arceo --- .../providers/vector_io/remote_pgvector.md | 73 ++++++ .../providers/vector_io/remote_weaviate.md | 1 + llama_stack/providers/registry/vector_io.py | 74 ++++++ .../remote/vector_io/pgvector/pgvector.py | 230 ++++++++++++++-- .../providers/utils/vector_io/vector_utils.py | 119 +++++++++ pyproject.toml | 2 + .../vector_io/test_openai_vector_stores.py | 2 + .../providers/utils/memory/test_reranking.py | 248 ++++++++++++++++++ tests/unit/providers/vector_io/conftest.py | 121 ++++++++- .../vector_io/remote/test_pgvector.py | 138 ++++++++++ uv.lock | 35 +++ 11 files changed, 1014 insertions(+), 29 deletions(-) create mode 100644 tests/unit/providers/utils/memory/test_reranking.py create mode 100644 tests/unit/providers/vector_io/remote/test_pgvector.py diff --git a/docs/source/providers/vector_io/remote_pgvector.md b/docs/source/providers/vector_io/remote_pgvector.md index 74f588a13..6312edabc 100644 --- a/docs/source/providers/vector_io/remote_pgvector.md +++ b/docs/source/providers/vector_io/remote_pgvector.md @@ -12,6 +12,60 @@ That means you'll get fast and efficient vector retrieval. - Easy to use - Fully integrated with Llama Stack +There are three implementations of search for PGVectoIndex available: + +1. Vector Search: +- How it works: + - Uses PostgreSQL's vector extension (pgvector) to perform similarity search + - Compares query embeddings against stored embeddings using Cosine distance or other distance metrics + - Eg. 
SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance + +-Characteristics: + - Semantic understanding - finds documents similar in meaning even if they don't share keywords + - Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions) + - Best for: Finding conceptually related content, handling synonyms, cross-language search + +2. Keyword Search +- How it works: + - Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank + - Converts text to searchable tokens using to_tsvector('english', text). Default language is English. + - Eg. SQL query: SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score + +- Characteristics: + - Lexical matching - finds exact keyword matches and variations + - Uses GIN (Generalized Inverted Index) for fast text search performance + - Scoring: Uses PostgreSQL's ts_rank function for relevance scoring + - Best for: Exact term matching, proper names, technical terms, Boolean-style queries + +3. Hybrid Search +- How it works: + - Combines both vector and keyword search results + - Runs both searches independently, then merges results using configurable reranking + +- Two reranking strategies available: + - Reciprocal Rank Fusion (RRF) - (default: 60.0) + - Weighted Average - (default: 0.5) + +- Characteristics: + - Best of both worlds: semantic understanding + exact matching + - Documents appearing in both searches get boosted scores + - Configurable balance between semantic and lexical matching + - Best for: General-purpose search where you want both precision and recall + +4. Database Schema +The PGVector implementation stores data optimized for all three search types: +CREATE TABLE vector_store_xxx ( + id TEXT PRIMARY KEY, + document JSONB, -- Original document + embedding vector(dimension), -- For vector search + content_text TEXT, -- Raw text content + tokenized_content TSVECTOR -- For keyword search +); + +-- Indexes for performance +CREATE INDEX content_gin_idx ON table USING GIN(tokenized_content); -- Keyword search +-- Vector index created automatically by pgvector + ## Usage To use PGVector in your Llama Stack project, follow these steps: @@ -20,6 +74,25 @@ To use PGVector in your Llama Stack project, follow these steps: 2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector). 3. Start storing and querying vectors. +## This is an example how you can set up your environment for using PGVector + +1. Export env vars: +```bash +export ENABLE_PGVECTOR=true +export PGVECTOR_HOST=localhost +export PGVECTOR_PORT=5432 +export PGVECTOR_DB=llamastack +export PGVECTOR_USER=llamastack +export PGVECTOR_PASSWORD=llamastack +``` + +2. 
Create DB: +```bash +psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';" +psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;" +psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;" +``` + ## Installation You can install PGVector using docker: diff --git a/docs/source/providers/vector_io/remote_weaviate.md b/docs/source/providers/vector_io/remote_weaviate.md index c59487cf6..8fb0f7c11 100644 --- a/docs/source/providers/vector_io/remote_weaviate.md +++ b/docs/source/providers/vector_io/remote_weaviate.md @@ -17,6 +17,7 @@ Weaviate supports: - Metadata filtering - Multi-modal retrieval + ## Usage To use Weaviate in your Llama Stack project, follow these steps: diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 70148eb15..511734d57 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -404,6 +404,60 @@ That means you'll get fast and efficient vector retrieval. - Easy to use - Fully integrated with Llama Stack +There are three implementations of search for PGVectoIndex available: + +1. Vector Search: +- How it works: + - Uses PostgreSQL's vector extension (pgvector) to perform similarity search + - Compares query embeddings against stored embeddings using Cosine distance or other distance metrics + - Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance + +-Characteristics: + - Semantic understanding - finds documents similar in meaning even if they don't share keywords + - Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions) + - Best for: Finding conceptually related content, handling synonyms, cross-language search + +2. Keyword Search +- How it works: + - Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank + - Converts text to searchable tokens using to_tsvector('english', text). Default language is English. + - Eg. SQL query: SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score + +- Characteristics: + - Lexical matching - finds exact keyword matches and variations + - Uses GIN (Generalized Inverted Index) for fast text search performance + - Scoring: Uses PostgreSQL's ts_rank function for relevance scoring + - Best for: Exact term matching, proper names, technical terms, Boolean-style queries + +3. Hybrid Search +- How it works: + - Combines both vector and keyword search results + - Runs both searches independently, then merges results using configurable reranking + +- Two reranking strategies available: + - Reciprocal Rank Fusion (RRF) - (default: 60.0) + - Weighted Average - (default: 0.5) + +- Characteristics: + - Best of both worlds: semantic understanding + exact matching + - Documents appearing in both searches get boosted scores + - Configurable balance between semantic and lexical matching + - Best for: General-purpose search where you want both precision and recall + +4. 
Database Schema +The PGVector implementation stores data optimized for all three search types: +CREATE TABLE vector_store_xxx ( + id TEXT PRIMARY KEY, + document JSONB, -- Original document + embedding vector(dimension), -- For vector search + content_text TEXT, -- Raw text content + tokenized_content TSVECTOR -- For keyword search +); + +-- Indexes for performance +CREATE INDEX content_gin_idx ON table USING GIN(tokenized_content); -- Keyword search +-- Vector index created automatically by pgvector + ## Usage To use PGVector in your Llama Stack project, follow these steps: @@ -412,6 +466,25 @@ To use PGVector in your Llama Stack project, follow these steps: 2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector). 3. Start storing and querying vectors. +## This is an example how you can set up your environment for using PGVector + +1. Export env vars: +```bash +export ENABLE_PGVECTOR=true +export PGVECTOR_HOST=localhost +export PGVECTOR_PORT=5432 +export PGVECTOR_DB=llamastack +export PGVECTOR_USER=llamastack +export PGVECTOR_PASSWORD=llamastack +``` + +2. Create DB: +```bash +psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';" +psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;" +psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;" +``` + ## Installation You can install PGVector using docker: @@ -449,6 +522,7 @@ Weaviate supports: - Metadata filtering - Multi-modal retrieval + ## Usage To use Weaviate in your Llama Stack project, follow these steps: diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 1c8d361c2..1c140e782 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
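+# heapq is introduced below for efficient top-k selection when merging vector
+# and keyword scores during hybrid search.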
+import heapq from typing import Any import psycopg2 @@ -23,6 +24,9 @@ from llama_stack.apis.vector_io import ( ) from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin @@ -31,6 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) +from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name from .config import PGVectorVectorIOConfig @@ -72,25 +77,63 @@ def load_models(cur, cls): class PGVectorIndex(EmbeddingIndex): - def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None): - self.conn = conn - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - # Sanitize the table name by replacing hyphens with underscores - # SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens - # when created with patterns like "test-vector-db-{uuid4()}" - sanitized_identifier = vector_db.identifier.replace("-", "_") - self.table_name = f"vector_store_{sanitized_identifier}" - self.kvstore = kvstore + # reference: https://github.com/pgvector/pgvector?tab=readme-ov-file#querying + PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION: dict[str, str] = { + "L2": "<->", + "L1": "<+>", + "COSINE": "<=>", + "INNER_PRODUCT": "<#>", + "HAMMING": "<~>", + "JACCARD": "<%>", + } - cur.execute( - f""" - CREATE TABLE IF NOT EXISTS {self.table_name} ( - id TEXT PRIMARY KEY, - document JSONB, - embedding vector({dimension}) + def __init__( + self, + vector_db: VectorDB, + dimension: int, + conn: psycopg2.extensions.connection, + kvstore: KVStore | None = None, + distance_metric: str = "COSINE", + ): + self.vector_db = vector_db + self.dimension = dimension + self.conn = conn + self.kvstore = kvstore + self.check_distance_metric_availability(distance_metric) + self.distance_metric = distance_metric + self.table_name = None + + async def initialize(self) -> None: + try: + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + # Sanitize the table name by replacing hyphens with underscores + # SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens + # when created with patterns like "test-vector-db-{uuid4()}" + sanitized_identifier = sanitize_collection_name(self.vector_db.identifier) + self.table_name = f"vs_{sanitized_identifier}" + + cur.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id TEXT PRIMARY KEY, + document JSONB, + embedding vector({self.dimension}), + content_text TEXT, + tokenized_content TSVECTOR + ) + """ ) - """ - ) + + # Create GIN index for full-text search performance + cur.execute( + f""" + CREATE INDEX IF NOT EXISTS {self.table_name}_content_gin_idx + ON {self.table_name} USING GIN(tokenized_content) + """ + ) + except Exception as e: + log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") + raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( @@ -99,29 +142,49 @@ class 
PGVectorIndex(EmbeddingIndex): values = [] for i, chunk in enumerate(chunks): + content_text = interleaved_content_as_str(chunk.content) values.append( ( f"{chunk.chunk_id}", Json(chunk.model_dump()), embeddings[i].tolist(), + content_text, + content_text, # Pass content_text twice - once for content_text column, once for to_tsvector function. Eg. to_tsvector(content_text) = tokenized_content ) ) query = sql.SQL( f""" - INSERT INTO {self.table_name} (id, document, embedding) + INSERT INTO {self.table_name} (id, document, embedding, content_text, tokenized_content) VALUES %s - ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document + ON CONFLICT (id) DO UPDATE SET + embedding = EXCLUDED.embedding, + document = EXCLUDED.document, + content_text = EXCLUDED.content_text, + tokenized_content = EXCLUDED.tokenized_content """ ) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - execute_values(cur, query, values, template="(%s, %s, %s::vector)") + execute_values(cur, query, values, template="(%s, %s, %s::vector, %s, to_tsvector('english', %s))") async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + """ + Performs vector similarity search using PostgreSQL's search function. Default distance metric is COSINE. + + Args: + embedding: The query embedding vector + k: Number of results to return + score_threshold: Minimum similarity score threshold + + Returns: + QueryChunksResponse with combined results + """ + pgvector_search_function = self.get_pgvector_search_function() + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute( f""" - SELECT document, embedding <-> %s::vector AS distance + SELECT document, embedding {pgvector_search_function} %s::vector AS distance FROM {self.table_name} ORDER BY distance LIMIT %s @@ -147,7 +210,40 @@ class PGVectorIndex(EmbeddingIndex): k: int, score_threshold: float, ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in PGVector") + """ + Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring. + + Args: + query_string: The text query for keyword search + k: Number of results to return + score_threshold: Minimum similarity score threshold + + Returns: + QueryChunksResponse with combined results + """ + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + # Use plainto_tsquery to handle user input safely and ts_rank for relevance scoring + cur.execute( + f""" + SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score + FROM {self.table_name} + WHERE tokenized_content @@ plainto_tsquery('english', %s) + ORDER BY score DESC + LIMIT %s + """, + (query_string, query_string, k), + ) + results = cur.fetchall() + + chunks = [] + scores = [] + for doc, score in results: + if score < score_threshold: + continue + chunks.append(Chunk(**doc)) + scores.append(float(score)) + + return QueryChunksResponse(chunks=chunks, scores=scores) async def query_hybrid( self, @@ -158,7 +254,59 @@ class PGVectorIndex(EmbeddingIndex): reranker_type: str, reranker_params: dict[str, Any] | None = None, ) -> QueryChunksResponse: - raise NotImplementedError("Hybrid search is not supported in PGVector") + """ + Hybrid search combining vector similarity and keyword search using configurable reranking. 
+ + Args: + embedding: The query embedding vector + query_string: The text query for keyword search + k: Number of results to return + score_threshold: Minimum similarity score threshold + reranker_type: Type of reranker to use ("rrf" or "weighted") + reranker_params: Parameters for the reranker + + Returns: + QueryChunksResponse with combined results + """ + if reranker_params is None: + reranker_params = {} + + # Get results from both search methods + vector_response = await self.query_vector(embedding, k, score_threshold) + keyword_response = await self.query_keyword(query_string, k, score_threshold) + + # Convert responses to score dictionaries using chunk_id + vector_scores = { + chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False) + } + keyword_scores = { + chunk.chunk_id: score + for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False) + } + + # Combine scores using the reranking utility + combined_scores = WeightedInMemoryAggregator.combine_search_results( + vector_scores, keyword_scores, reranker_type, reranker_params + ) + + # Efficient top-k selection because it only tracks the k best candidates it's seen so far + top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1]) + + # Filter by score threshold + filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold] + + # Create a map of chunk_id to chunk for both responses + chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks} + + # Use the map to look up chunks by their IDs + chunks = [] + scores = [] + for doc_id, score in filtered_items: + if doc_id in chunk_map: + chunks.append(chunk_map[doc_id]) + scores.append(score) + + return QueryChunksResponse(chunks=chunks, scores=scores) async def delete(self): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: @@ -170,6 +318,25 @@ class PGVectorIndex(EmbeddingIndex): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,)) + def get_pgvector_search_function(self) -> str: + return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric] + + def check_distance_metric_availability(self, distance_metric: str) -> None: + """Check if the distance metric is supported by PGVector. + + Args: + distance_metric: The distance metric to check + + Raises: + ValueError: If the distance metric is not supported + """ + if distance_metric not in self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION: + supported_metrics = list(self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION.keys()) + raise ValueError( + f"Distance metric '{distance_metric}' is not supported by PGVector. 
" + f"Supported metrics are: {', '.join(supported_metrics)}" + ) + class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__( @@ -185,8 +352,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco self.files_api = files_api self.kvstore: KVStore | None = None self.vector_db_store = None - self.openai_vector_store: dict[str, dict[str, Any]] = {} - self.metadatadata_collection_name = "openai_vector_stores_metadata" + self.openai_vector_stores: dict[str, dict[str, Any]] = {} + self.metadata_collection_name = "openai_vector_stores_metadata" async def initialize(self) -> None: log.info(f"Initializing PGVector memory adapter with config: {self.config}") @@ -233,9 +400,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco upsert_models(self.conn, [(vector_db.identifier, vector_db)]) # Create and cache the PGVector index table for the vector DB + pgvector_index = PGVectorIndex( + vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore + ) + await pgvector_index.initialize() index = VectorDBWithIndex( vector_db, - index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore), + index=pgvector_index, inference_api=self.inference_api, ) self.cache[vector_db.identifier] = index @@ -272,8 +443,15 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco if vector_db_id in self.cache: return self.cache[vector_db_id] + if self.vector_db_store is None: + raise VectorStoreNotFoundError(vector_db_id) + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) + if not vector_db: + raise VectorStoreNotFoundError(vector_db_id) + index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn) + await index.initialize() self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) return self.cache[vector_db_id] diff --git a/llama_stack/providers/utils/vector_io/vector_utils.py b/llama_stack/providers/utils/vector_io/vector_utils.py index f2888043e..e55ac75ae 100644 --- a/llama_stack/providers/utils/vector_io/vector_utils.py +++ b/llama_stack/providers/utils/vector_io/vector_utils.py @@ -37,3 +37,122 @@ def sanitize_collection_name(name: str, weaviate_format=False) -> str: else: s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name)) return s + + +class WeightedInMemoryAggregator: + @staticmethod + def _normalize_scores(scores: dict[str, float]) -> dict[str, float]: + """ + Normalize scores to 0-1 range using min-max normalization. + + Args: + scores: dictionary of scores with document IDs as keys and scores as values + + Returns: + Normalized scores with document IDs as keys and normalized scores as values + """ + if not scores: + return {} + min_score, max_score = min(scores.values()), max(scores.values()) + score_range = max_score - min_score + if score_range > 0: + return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()} + return dict.fromkeys(scores, 1.0) + + @staticmethod + def weighted_rerank( + vector_scores: dict[str, float], + keyword_scores: dict[str, float], + alpha: float = 0.5, + ) -> dict[str, float]: + """ + Rerank via weighted average of scores. 
+ + Args: + vector_scores: scores from vector search + keyword_scores: scores from keyword search + alpha: weight factor between 0 and 1 (default: 0.5) + 0 = keyword only, 1 = vector only, 0.5 = equal weight + + Returns: + All unique document IDs with weighted combined scores + """ + all_ids = set(vector_scores.keys()) | set(keyword_scores.keys()) + normalized_vector_scores = WeightedInMemoryAggregator._normalize_scores(vector_scores) + normalized_keyword_scores = WeightedInMemoryAggregator._normalize_scores(keyword_scores) + + # Weighted formula: score = (1-alpha) * keyword_score + alpha * vector_score + # alpha=0 means keyword only, alpha=1 means vector only + return { + doc_id: ((1 - alpha) * normalized_keyword_scores.get(doc_id, 0.0)) + + (alpha * normalized_vector_scores.get(doc_id, 0.0)) + for doc_id in all_ids + } + + @staticmethod + def rrf_rerank( + vector_scores: dict[str, float], + keyword_scores: dict[str, float], + impact_factor: float = 60.0, + ) -> dict[str, float]: + """ + Rerank via Reciprocal Rank Fusion. + + Args: + vector_scores: scores from vector search + keyword_scores: scores from keyword search + impact_factor: impact factor for RRF (default: 60.0) + + Returns: + All unique document IDs with RRF combined scores + """ + + # Convert scores to ranks + vector_ranks = { + doc_id: i + 1 + for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True)) + } + keyword_ranks = { + doc_id: i + 1 + for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True)) + } + + all_ids = set(vector_scores.keys()) | set(keyword_scores.keys()) + rrf_scores = {} + for doc_id in all_ids: + vector_rank = vector_ranks.get(doc_id, float("inf")) + keyword_rank = keyword_ranks.get(doc_id, float("inf")) + + # RRF formula: score = 1/(k + r) where k is impact_factor (default: 60.0) and r is the rank + rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank)) + return rrf_scores + + @staticmethod + def combine_search_results( + vector_scores: dict[str, float], + keyword_scores: dict[str, float], + reranker_type: str = "rrf", + reranker_params: dict[str, float] | None = None, + ) -> dict[str, float]: + """ + Combine vector and keyword search results using specified reranking strategy. 
+ + Args: + vector_scores: scores from vector search + keyword_scores: scores from keyword search + reranker_type: type of reranker to use (default: RERANKER_TYPE_RRF) + reranker_params: parameters for the reranker + + Returns: + All unique document IDs with combined scores + """ + if reranker_params is None: + reranker_params = {} + + if reranker_type == "weighted": + alpha = reranker_params.get("alpha", 0.5) + return WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha) + else: + # Default to RRF for None, RRF, or any unknown types + impact_factor = reranker_params.get("impact_factor", 60.0) + return WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor) diff --git a/pyproject.toml b/pyproject.toml index dd8529546..3ab042a8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ unit = [ "openai", "aiosqlite", "aiohttp", + "psycopg2-binary>=2.9.0", "pypdf", "mcp", "chardet", @@ -111,6 +112,7 @@ test = [ "torch>=2.6.0", "torchvision>=0.21.0", "chardet", + "psycopg2-binary>=2.9.0", "pypdf", "mcp", "datasets", diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 82868164f..c67036eab 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -57,11 +57,13 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode "inline::sqlite-vec", "remote::milvus", "inline::milvus", + "remote::pgvector", ], "hybrid": [ "inline::sqlite-vec", "inline::milvus", "remote::milvus", + "remote::pgvector", ], } supported_providers = search_mode_support.get(search_mode, []) diff --git a/tests/unit/providers/utils/memory/test_reranking.py b/tests/unit/providers/utils/memory/test_reranking.py new file mode 100644 index 000000000..02d7a1b6a --- /dev/null +++ b/tests/unit/providers/utils/memory/test_reranking.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
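+# These tests cover the WeightedInMemoryAggregator used for hybrid search:
+# min-max score normalization, weighted-average reranking
+# (score = (1 - alpha) * keyword + alpha * vector on normalized scores), and
+# Reciprocal Rank Fusion (score = 1/(k + rank) summed over both searches).
+# For example, with the default impact factor k=60, a document ranked 1st by
+# vector search and 3rd by keyword search scores 1/61 + 1/63 ≈ 0.0323.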
+ + +from llama_stack.providers.utils.memory.vector_store import RERANKER_TYPE_RRF, RERANKER_TYPE_WEIGHTED +from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator + + +class TestNormalizeScores: + """Test cases for score normalization.""" + + def test_normalize_scores_basic(self): + """Test basic score normalization.""" + scores = {"doc1": 10.0, "doc2": 5.0, "doc3": 0.0} + normalized = WeightedInMemoryAggregator._normalize_scores(scores) + + assert normalized["doc1"] == 1.0 # Max score + assert normalized["doc3"] == 0.0 # Min score + assert normalized["doc2"] == 0.5 # Middle score + assert all(0 <= score <= 1 for score in normalized.values()) + + def test_normalize_scores_identical(self): + """Test normalization when all scores are identical.""" + scores = {"doc1": 5.0, "doc2": 5.0, "doc3": 5.0} + normalized = WeightedInMemoryAggregator._normalize_scores(scores) + + # All scores should be 1.0 when identical + assert all(score == 1.0 for score in normalized.values()) + + def test_normalize_scores_empty(self): + """Test normalization with empty scores.""" + scores = {} + normalized = WeightedInMemoryAggregator._normalize_scores(scores) + + assert normalized == {} + + def test_normalize_scores_single(self): + """Test normalization with single score.""" + scores = {"doc1": 7.5} + normalized = WeightedInMemoryAggregator._normalize_scores(scores) + + assert normalized["doc1"] == 1.0 + + +class TestWeightedRerank: + """Test cases for weighted reranking.""" + + def test_weighted_rerank_basic(self): + """Test basic weighted reranking.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} + keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9} + + combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.5) + + # Should include all documents + expected_docs = {"doc1", "doc2", "doc3", "doc4"} + assert set(combined.keys()) == expected_docs + + # All scores should be between 0 and 1 + assert all(0 <= score <= 1 for score in combined.values()) + + # doc1 appears in both searches, should have higher combined score + assert combined["doc1"] > 0 + + def test_weighted_rerank_alpha_zero(self): + """Test weighted reranking with alpha=0 (keyword only).""" + vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} # All docs present in vector + keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9} # All docs present in keyword + + combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.0) + + # Alpha=0 means vector scores are ignored, keyword scores dominate + # doc3 should score highest since it has highest keyword score + assert combined["doc3"] > combined["doc2"] > combined["doc1"] + + def test_weighted_rerank_alpha_one(self): + """Test weighted reranking with alpha=1 (vector only).""" + vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} # All docs present in vector + keyword_scores = {"doc1": 0.1, "doc2": 0.3, "doc3": 0.9} # All docs present in keyword + + combined = WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=1.0) + + # Alpha=1 means keyword scores are ignored, vector scores dominate + # doc1 should score highest since it has highest vector score + assert combined["doc1"] > combined["doc2"] > combined["doc3"] + + def test_weighted_rerank_no_overlap(self): + """Test weighted reranking with no overlapping documents.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc3": 0.8, "doc4": 0.6} + + combined = 
WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.5) + + assert len(combined) == 4 + # With min-max normalization, lowest scoring docs in each group get 0.0 + # but highest scoring docs should get positive scores + assert all(score >= 0 for score in combined.values()) + assert combined["doc1"] > 0 # highest vector score + assert combined["doc3"] > 0 # highest keyword score + + +class TestRRFRerank: + """Test cases for RRF (Reciprocal Rank Fusion) reranking.""" + + def test_rrf_rerank_basic(self): + """Test basic RRF reranking.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7, "doc3": 0.5} + keyword_scores = {"doc1": 0.6, "doc2": 0.8, "doc4": 0.9} + + combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0) + + # Should include all documents + expected_docs = {"doc1", "doc2", "doc3", "doc4"} + assert set(combined.keys()) == expected_docs + + # All scores should be positive + assert all(score > 0 for score in combined.values()) + + # Documents appearing in both searches should have higher scores + # doc1 and doc2 appear in both, doc3 and doc4 appear in only one + assert combined["doc1"] > combined["doc3"] + assert combined["doc2"] > combined["doc4"] + + def test_rrf_rerank_rank_calculation(self): + """Test that RRF correctly calculates ranks.""" + # Create clear ranking order + vector_scores = {"doc1": 1.0, "doc2": 0.8, "doc3": 0.6} # Ranks: 1, 2, 3 + keyword_scores = {"doc1": 0.5, "doc2": 1.0, "doc3": 0.7} # Ranks: 3, 1, 2 + + combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0) + + # doc1: rank 1 in vector, rank 3 in keyword + # doc2: rank 2 in vector, rank 1 in keyword + # doc3: rank 3 in vector, rank 2 in keyword + + # doc2 should have the highest combined score (ranks 2+1=3) + # followed by doc1 (ranks 1+3=4) and doc3 (ranks 3+2=5) + # Remember: lower rank sum = higher RRF score + assert combined["doc2"] > combined["doc1"] > combined["doc3"] + + def test_rrf_rerank_impact_factor(self): + """Test that impact factor affects RRF scores.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc1": 0.8, "doc2": 0.6} + + combined_low = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=10.0) + combined_high = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=100.0) + + # Higher impact factor should generally result in lower scores + # (because 1/(k+r) decreases as k increases) + assert combined_low["doc1"] > combined_high["doc1"] + assert combined_low["doc2"] > combined_high["doc2"] + + def test_rrf_rerank_missing_documents(self): + """Test RRF handling of documents missing from one search.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc1": 0.8, "doc3": 0.6} + + combined = WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0) + + # Should include all documents + assert len(combined) == 3 + + # doc1 appears in both searches, should have highest score + assert combined["doc1"] > combined["doc2"] + assert combined["doc1"] > combined["doc3"] + + +class TestCombineSearchResults: + """Test cases for the main combine_search_results function.""" + + def test_combine_search_results_rrf_default(self): + """Test combining with RRF as default.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc1": 0.6, "doc3": 0.8} + + combined = WeightedInMemoryAggregator.combine_search_results(vector_scores, keyword_scores) + + # Should 
default to RRF + assert len(combined) == 3 + assert all(score > 0 for score in combined.values()) + + def test_combine_search_results_rrf_explicit(self): + """Test combining with explicit RRF.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc1": 0.6, "doc3": 0.8} + + combined = WeightedInMemoryAggregator.combine_search_results( + vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_RRF, reranker_params={"impact_factor": 50.0} + ) + + assert len(combined) == 3 + assert all(score > 0 for score in combined.values()) + + def test_combine_search_results_weighted(self): + """Test combining with weighted reranking.""" + vector_scores = {"doc1": 0.9, "doc2": 0.7} + keyword_scores = {"doc1": 0.6, "doc3": 0.8} + + combined = WeightedInMemoryAggregator.combine_search_results( + vector_scores, keyword_scores, reranker_type=RERANKER_TYPE_WEIGHTED, reranker_params={"alpha": 0.3} + ) + + assert len(combined) == 3 + assert all(0 <= score <= 1 for score in combined.values()) + + def test_combine_search_results_unknown_type(self): + """Test combining with unknown reranker type defaults to RRF.""" + vector_scores = {"doc1": 0.9} + keyword_scores = {"doc2": 0.8} + + combined = WeightedInMemoryAggregator.combine_search_results( + vector_scores, keyword_scores, reranker_type="unknown_type" + ) + + # Should fall back to RRF + assert len(combined) == 2 + assert all(score > 0 for score in combined.values()) + + def test_combine_search_results_empty_params(self): + """Test combining with empty parameters.""" + vector_scores = {"doc1": 0.9} + keyword_scores = {"doc2": 0.8} + + combined = WeightedInMemoryAggregator.combine_search_results(vector_scores, keyword_scores, reranker_params={}) + + # Should use default parameters + assert len(combined) == 2 + assert all(score > 0 for score in combined.values()) + + def test_combine_search_results_empty_scores(self): + """Test combining with empty score dictionaries.""" + # Test with empty vector scores + combined = WeightedInMemoryAggregator.combine_search_results({}, {"doc1": 0.8}) + assert len(combined) == 1 + assert combined["doc1"] > 0 + + # Test with empty keyword scores + combined = WeightedInMemoryAggregator.combine_search_results({"doc1": 0.9}, {}) + assert len(combined) == 1 + assert combined["doc1"] > 0 + + # Test with both empty + combined = WeightedInMemoryAggregator.combine_search_results({}, {}) + assert len(combined) == 0 diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index f71073651..91bddd037 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -5,6 +5,7 @@ # the root directory of this source tree. 
 import random
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import numpy as np
 import pytest
@@ -12,7 +13,7 @@ from chromadb import PersistentClient
 from pymilvus import MilvusClient, connections
 
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import Chunk, ChunkMetadata
+from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
 from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
@@ -22,6 +23,8 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf
 from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
 from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
 from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
+from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
+from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
 from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
 
 EMBEDDING_DIMENSION = 384
@@ -29,7 +32,7 @@ COLLECTION_PREFIX = "test_collection"
 MILVUS_ALIAS = "test_milvus"
 
 
-@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
+@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
 def vector_provider(request):
     return request.param
 
@@ -333,15 +336,127 @@ async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
     await index.delete()
+
+
+@pytest.fixture
+def mock_psycopg2_connection():
+    connection = MagicMock()
+    cursor = MagicMock()
+
+    cursor.__enter__ = MagicMock(return_value=cursor)
+    cursor.__exit__ = MagicMock()
+
+    connection.cursor.return_value = cursor
+
+    return connection, cursor
+
+
+@pytest.fixture
+async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
+    connection, cursor = mock_psycopg2_connection
+
+    vector_db = VectorDB(
+        identifier="test-vector-db",
+        embedding_model="test-model",
+        embedding_dimension=embedding_dimension,
+        provider_id="pgvector",
+        provider_resource_id="pgvector:test-vector-db",
+    )
+
+    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
+        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
+            index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
+            index._test_chunks = []
+            original_add_chunks = index.add_chunks
+
+            async def mock_add_chunks(chunks, embeddings):
+                index._test_chunks = list(chunks)
+                await original_add_chunks(chunks, embeddings)
+
+            index.add_chunks = mock_add_chunks
+
+            async def mock_query_vector(embedding, k, score_threshold):
+                chunks = index._test_chunks[:k] if hasattr(index, "_test_chunks") else []
+                scores = [1.0] * len(chunks)
+                return QueryChunksResponse(chunks=chunks, scores=scores)
+
+            index.query_vector = mock_query_vector
+
+            yield index
+
+
+@pytest.fixture
+async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
+    config = PGVectorVectorIOConfig(
+        host="localhost",
+        port=5432,
+        db="test_db",
+        user="test_user",
+        password="test_password",
+        kvstore=SqliteKVStoreConfig(),
+    )
+
+    adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
+
+    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2.connect") as mock_connect:
+        mock_conn = MagicMock()
+        mock_cursor = MagicMock()
+        mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
+        mock_cursor.__exit__ = MagicMock()
+        mock_conn.cursor.return_value = mock_cursor
+        mock_conn.autocommit = True
+        mock_connect.return_value = mock_conn
+
+        with patch(
+            "llama_stack.providers.remote.vector_io.pgvector.pgvector.check_extension_version"
+        ) as mock_check_version:
+            mock_check_version.return_value = "0.5.1"
+
+            with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl:
+                mock_kvstore = AsyncMock()
+                mock_kvstore_impl.return_value = mock_kvstore
+
+                with patch.object(adapter, "initialize_openai_vector_stores", new_callable=AsyncMock):
+                    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.upsert_models"):
+                        await adapter.initialize()
+                        adapter.conn = mock_conn
+
+                        async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
+                            index = await adapter._get_and_cache_vector_db_index(vector_db_id)
+                            if not index:
+                                raise ValueError(f"Vector DB {vector_db_id} not found")
+                            await index.insert_chunks(chunks)
+
+                        adapter.insert_chunks = mock_insert_chunks
+
+                        async def mock_query_chunks(vector_db_id, query, params=None):
+                            index = await adapter._get_and_cache_vector_db_index(vector_db_id)
+                            if not index:
+                                raise ValueError(f"Vector DB {vector_db_id} not found")
+                            return await index.query_chunks(query, params)
+
+                        adapter.query_chunks = mock_query_chunks
+
+                        test_vector_db = VectorDB(
+                            identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
+                            provider_id="test_provider",
+                            embedding_model="test_model",
+                            embedding_dimension=embedding_dimension,
+                        )
+                        await adapter.register_vector_db(test_vector_db)
+                        adapter.test_collection_id = test_vector_db.identifier
+
+                        yield adapter
+                        await adapter.shutdown()
+
+
 @pytest.fixture
 def vector_io_adapter(vector_provider, request):
-    """Returns the appropriate vector IO adapter based on the provider parameter."""
     vector_provider_dict = {
         "milvus": "milvus_vec_adapter",
         "faiss": "faiss_vec_adapter",
         "sqlite_vec": "sqlite_vec_adapter",
         "chroma": "chroma_vec_adapter",
         "qdrant": "qdrant_vec_adapter",
+        "pgvector": "pgvector_vec_adapter",
     }
     return request.getfixturevalue(vector_provider_dict[vector_provider])
diff --git a/tests/unit/providers/vector_io/remote/test_pgvector.py b/tests/unit/providers/vector_io/remote/test_pgvector.py
new file mode 100644
index 000000000..6f498bf46
--- /dev/null
+++ b/tests/unit/providers/vector_io/remote/test_pgvector.py
@@ -0,0 +1,138 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+from unittest.mock import patch
+
+import pytest
+
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex
+
+PGVECTOR_PROVIDER = "pgvector"
+
+
+@pytest.fixture(scope="session")
+def loop():
+    return asyncio.new_event_loop()
+
+
+@pytest.fixture
+def embedding_dimension():
+    """Default embedding dimension for tests."""
+    return 384
+
+
+@pytest.fixture
+async def pgvector_index(embedding_dimension, mock_psycopg2_connection):
+    """Create a PGVectorIndex instance with mocked database connection."""
+    connection, cursor = mock_psycopg2_connection
+
+    vector_db = VectorDB(
+        identifier="test-vector-db",
+        embedding_model="test-model",
+        embedding_dimension=embedding_dimension,
+        provider_id=PGVECTOR_PROVIDER,
+        provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
+    )
+
+    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
+        # Use explicit COSINE distance metric for consistent testing
+        index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
+
+    return index, cursor
+
+
+class TestPGVectorIndex:
+    def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection):
+        connection, cursor = mock_psycopg2_connection
+
+        vector_db = VectorDB(
+            identifier="test-vector-db",
+            embedding_model="test-model",
+            embedding_dimension=embedding_dimension,
+            provider_id=PGVECTOR_PROVIDER,
+            provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
+        )
+
+        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
+            index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2")
+            assert index.distance_metric == "L2"
+            with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
+                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID")
+
+    def test_get_pgvector_search_function(self, pgvector_index):
+        index, cursor = pgvector_index
+        supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
+
+        for metric, function in supported_metrics.items():
+            index.distance_metric = metric
+            assert index.get_pgvector_search_function() == function
+
+    def test_check_distance_metric_availability(self, pgvector_index):
+        index, cursor = pgvector_index
+        supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
+
+        for metric in supported_metrics:
+            index.check_distance_metric_availability(metric)
+
+        with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
+            index.check_distance_metric_availability("INVALID")
+
+    def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection):
+        connection, cursor = mock_psycopg2_connection
+
+        vector_db = VectorDB(
+            identifier="test-vector-db",
+            embedding_model="test-model",
+            embedding_dimension=embedding_dimension,
+            provider_id=PGVECTOR_PROVIDER,
+            provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
+        )
+
+        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
+            with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"):
+                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC")
+
+            with pytest.raises(ValueError, match="Supported metrics are:"):
+                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN")
+
+            try:
+                index = PGVectorIndex(vector_db, embedding_dimension, connection,
distance_metric="COSINE") + assert index.distance_metric == "COSINE" + except ValueError: + pytest.fail("Valid distance metric 'COSINE' should not raise ValueError") + + def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection): + connection, cursor = mock_psycopg2_connection + + vector_db = VectorDB( + identifier="test-vector-db", + embedding_model="test-model", + embedding_dimension=embedding_dimension, + provider_id=PGVECTOR_PROVIDER, + provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db", + ) + + supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"] + + with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"): + for metric in supported_metrics: + try: + index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric) + assert index.distance_metric == metric + + expected_operators = { + "L2": "<->", + "L1": "<+>", + "COSINE": "<=>", + "INNER_PRODUCT": "<#>", + "HAMMING": "<~>", + "JACCARD": "<%>", + } + assert index.get_pgvector_search_function() == expected_operators[metric] + except Exception as e: + pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}") diff --git a/uv.lock b/uv.lock index 0626caba6..b47eeccc4 100644 --- a/uv.lock +++ b/uv.lock @@ -1859,6 +1859,7 @@ test = [ { name = "mcp" }, { name = "milvus-lite" }, { name = "openai" }, + { name = "psycopg2-binary" }, { name = "pymilvus" }, { name = "pypdf" }, { name = "requests" }, @@ -1884,6 +1885,7 @@ unit = [ { name = "moto", extra = ["s3"] }, { name = "ollama" }, { name = "openai" }, + { name = "psycopg2-binary" }, { name = "pymilvus" }, { name = "pypdf" }, { name = "qdrant-client" }, @@ -1978,6 +1980,7 @@ test = [ { name = "mcp" }, { name = "milvus-lite", specifier = ">=2.5.0" }, { name = "openai" }, + { name = "psycopg2-binary", specifier = ">=2.9.0" }, { name = "pymilvus", specifier = ">=2.5.12" }, { name = "pypdf" }, { name = "requests" }, @@ -2002,6 +2005,7 @@ unit = [ { name = "moto", extras = ["s3"], specifier = ">=5.1.10" }, { name = "ollama" }, { name = "openai" }, + { name = "psycopg2-binary", specifier = ">=2.9.0" }, { name = "pymilvus", specifier = ">=2.5.12" }, { name = "pypdf" }, { name = "qdrant-client" }, @@ -3139,6 +3143,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885, upload-time = "2025-02-13T21:54:37.486Z" }, ] +[[package]] +name = "psycopg2-binary" +version = "2.9.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cb/0e/bdc8274dc0585090b4e3432267d7be4dfbfd8971c0fa59167c711105a6bf/psycopg2-binary-2.9.10.tar.gz", hash = "sha256:4b3df0e6990aa98acda57d983942eff13d824135fe2250e6522edaa782a06de2", size = 385764, upload-time = "2024-10-16T11:24:58.126Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/7d/465cc9795cf76f6d329efdafca74693714556ea3891813701ac1fee87545/psycopg2_binary-2.9.10-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:880845dfe1f85d9d5f7c412efea7a08946a46894537e4e5d091732eb1d34d9a0", size = 3044771, upload-time = "2024-10-16T11:20:35.234Z" }, + { url = "https://files.pythonhosted.org/packages/8b/31/6d225b7b641a1a2148e3ed65e1aa74fc86ba3fee850545e27be9e1de893d/psycopg2_binary-2.9.10-cp312-cp312-macosx_14_0_arm64.whl", hash = 
"sha256:9440fa522a79356aaa482aa4ba500b65f28e5d0e63b801abf6aa152a29bd842a", size = 3275336, upload-time = "2024-10-16T11:20:38.742Z" }, + { url = "https://files.pythonhosted.org/packages/30/b7/a68c2b4bff1cbb1728e3ec864b2d92327c77ad52edcd27922535a8366f68/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3923c1d9870c49a2d44f795df0c889a22380d36ef92440ff618ec315757e539", size = 2851637, upload-time = "2024-10-16T11:20:42.145Z" }, + { url = "https://files.pythonhosted.org/packages/0b/b1/cfedc0e0e6f9ad61f8657fd173b2f831ce261c02a08c0b09c652b127d813/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b2c956c028ea5de47ff3a8d6b3cc3330ab45cf0b7c3da35a2d6ff8420896526", size = 3082097, upload-time = "2024-10-16T11:20:46.185Z" }, + { url = "https://files.pythonhosted.org/packages/18/ed/0a8e4153c9b769f59c02fb5e7914f20f0b2483a19dae7bf2db54b743d0d0/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f758ed67cab30b9a8d2833609513ce4d3bd027641673d4ebc9c067e4d208eec1", size = 3264776, upload-time = "2024-10-16T11:20:50.879Z" }, + { url = "https://files.pythonhosted.org/packages/10/db/d09da68c6a0cdab41566b74e0a6068a425f077169bed0946559b7348ebe9/psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cd9b4f2cfab88ed4a9106192de509464b75a906462fb846b936eabe45c2063e", size = 3020968, upload-time = "2024-10-16T11:20:56.819Z" }, + { url = "https://files.pythonhosted.org/packages/94/28/4d6f8c255f0dfffb410db2b3f9ac5218d959a66c715c34cac31081e19b95/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dc08420625b5a20b53551c50deae6e231e6371194fa0651dbe0fb206452ae1f", size = 2872334, upload-time = "2024-10-16T11:21:02.411Z" }, + { url = "https://files.pythonhosted.org/packages/05/f7/20d7bf796593c4fea95e12119d6cc384ff1f6141a24fbb7df5a668d29d29/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d7cd730dfa7c36dbe8724426bf5612798734bff2d3c3857f36f2733f5bfc7c00", size = 2822722, upload-time = "2024-10-16T11:21:09.01Z" }, + { url = "https://files.pythonhosted.org/packages/4d/e4/0c407ae919ef626dbdb32835a03b6737013c3cc7240169843965cada2bdf/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:155e69561d54d02b3c3209545fb08938e27889ff5a10c19de8d23eb5a41be8a5", size = 2920132, upload-time = "2024-10-16T11:21:16.339Z" }, + { url = "https://files.pythonhosted.org/packages/2d/70/aa69c9f69cf09a01da224909ff6ce8b68faeef476f00f7ec377e8f03be70/psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3cc28a6fd5a4a26224007712e79b81dbaee2ffb90ff406256158ec4d7b52b47", size = 2959312, upload-time = "2024-10-16T11:21:25.584Z" }, + { url = "https://files.pythonhosted.org/packages/d3/bd/213e59854fafe87ba47814bf413ace0dcee33a89c8c8c814faca6bc7cf3c/psycopg2_binary-2.9.10-cp312-cp312-win32.whl", hash = "sha256:ec8a77f521a17506a24a5f626cb2aee7850f9b69a0afe704586f63a464f3cd64", size = 1025191, upload-time = "2024-10-16T11:21:29.912Z" }, + { url = "https://files.pythonhosted.org/packages/92/29/06261ea000e2dc1e22907dbbc483a1093665509ea586b29b8986a0e56733/psycopg2_binary-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:18c5ee682b9c6dd3696dad6e54cc7ff3a1a9020df6a5c0f861ef8bfd338c3ca0", size = 1164031, upload-time = "2024-10-16T11:21:34.211Z" }, + { url = 
"https://files.pythonhosted.org/packages/3e/30/d41d3ba765609c0763505d565c4d12d8f3c79793f0d0f044ff5a28bf395b/psycopg2_binary-2.9.10-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:26540d4a9a4e2b096f1ff9cce51253d0504dca5a85872c7f7be23be5a53eb18d", size = 3044699, upload-time = "2024-10-16T11:21:42.841Z" }, + { url = "https://files.pythonhosted.org/packages/35/44/257ddadec7ef04536ba71af6bc6a75ec05c5343004a7ec93006bee66c0bc/psycopg2_binary-2.9.10-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e217ce4d37667df0bc1c397fdcd8de5e81018ef305aed9415c3b093faaeb10fb", size = 3275245, upload-time = "2024-10-16T11:21:51.989Z" }, + { url = "https://files.pythonhosted.org/packages/1b/11/48ea1cd11de67f9efd7262085588790a95d9dfcd9b8a687d46caf7305c1a/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:245159e7ab20a71d989da00f280ca57da7641fa2cdcf71749c193cea540a74f7", size = 2851631, upload-time = "2024-10-16T11:21:57.584Z" }, + { url = "https://files.pythonhosted.org/packages/62/e0/62ce5ee650e6c86719d621a761fe4bc846ab9eff8c1f12b1ed5741bf1c9b/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c4ded1a24b20021ebe677b7b08ad10bf09aac197d6943bfe6fec70ac4e4690d", size = 3082140, upload-time = "2024-10-16T11:22:02.005Z" }, + { url = "https://files.pythonhosted.org/packages/27/ce/63f946c098611f7be234c0dd7cb1ad68b0b5744d34f68062bb3c5aa510c8/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3abb691ff9e57d4a93355f60d4f4c1dd2d68326c968e7db17ea96df3c023ef73", size = 3264762, upload-time = "2024-10-16T11:22:06.412Z" }, + { url = "https://files.pythonhosted.org/packages/43/25/c603cd81402e69edf7daa59b1602bd41eb9859e2824b8c0855d748366ac9/psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8608c078134f0b3cbd9f89b34bd60a943b23fd33cc5f065e8d5f840061bd0673", size = 3020967, upload-time = "2024-10-16T11:22:11.583Z" }, + { url = "https://files.pythonhosted.org/packages/5f/d6/8708d8c6fca531057fa170cdde8df870e8b6a9b136e82b361c65e42b841e/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:230eeae2d71594103cd5b93fd29d1ace6420d0b86f4778739cb1a5a32f607d1f", size = 2872326, upload-time = "2024-10-16T11:22:16.406Z" }, + { url = "https://files.pythonhosted.org/packages/ce/ac/5b1ea50fc08a9df82de7e1771537557f07c2632231bbab652c7e22597908/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909", size = 2822712, upload-time = "2024-10-16T11:22:21.366Z" }, + { url = "https://files.pythonhosted.org/packages/c4/fc/504d4503b2abc4570fac3ca56eb8fed5e437bf9c9ef13f36b6621db8ef00/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1", size = 2920155, upload-time = "2024-10-16T11:22:25.684Z" }, + { url = "https://files.pythonhosted.org/packages/b2/d1/323581e9273ad2c0dbd1902f3fb50c441da86e894b6e25a73c3fda32c57e/psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567", size = 2959356, upload-time = "2024-10-16T11:22:30.562Z" }, + { url = "https://files.pythonhosted.org/packages/08/50/d13ea0a054189ae1bc21af1d85b6f8bb9bbc5572991055d70ad9006fe2d6/psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142", size = 
2569224, upload-time = "2025-01-04T20:09:19.234Z" }, +] + [[package]] name = "ptyprocess" version = "0.7.0" From efdb5558b8dcab4d141678bfed0a405e2f312b6f Mon Sep 17 00:00:00 2001 From: slekkala1 Date: Fri, 29 Aug 2025 11:03:52 -0700 Subject: [PATCH 3/3] fix: Remove bfcl scoring function as not supported (#3281) # What does this PR do? BFCL scoring function is not supported, removing it. Also minor fixes as the llama stack run is broken for open-benchmark for test plan verification 1. Correct the model paths for supported models 2. Fix another issue as there is no `provider_id` for DatasetInput but logger assumes it exists. ``` File "/Users/swapna942/llama-stack/llama_stack/core/stack.py", line 332, in construct_stack await register_resources(run_config, impls) File "/Users/swapna942/llama-stack/llama_stack/core/stack.py", line 108, in register_resources logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}") ^^^^^^^^^^^^^^^ File "/Users/swapna942/llama-stack/.venv/lib/python3.13/site-packages/pydantic/main.py", line 991, in __getattr__ raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') AttributeError: 'DatasetInput' object has no attribute 'provider_id' ``` ## Test Plan ```llama stack build --distro open-benchmark --image-type venv``` and run the server succeeds Issue Link: https://github.com/llamastack/llama-stack/issues/3282 --- .../llama_stack_client_cli_reference.md | 1 - llama_stack/core/stack.py | 12 +- .../open-benchmark/open_benchmark.py | 16 +- .../distributions/open-benchmark/run.yaml | 19 +- .../providers/inline/scoring/basic/scoring.py | 2 - .../basic/scoring_fn/bfcl_scoring_fn.py | 93 -- .../scoring/basic/scoring_fn/fn_defs/bfcl.py | 21 - .../scoring/basic/utils/bfcl/__init__.py | 5 - .../scoring/basic/utils/bfcl/ast_parser.py | 296 ------ .../scoring/basic/utils/bfcl/checker.py | 989 ------------------ .../scoring/basic/utils/bfcl/tree_sitter.py | 40 - 11 files changed, 12 insertions(+), 1482 deletions(-) delete mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py delete mode 100644 llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py delete mode 100644 llama_stack/providers/inline/scoring/basic/utils/bfcl/__init__.py delete mode 100644 llama_stack/providers/inline/scoring/basic/utils/bfcl/ast_parser.py delete mode 100644 llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py delete mode 100644 llama_stack/providers/inline/scoring/basic/utils/bfcl/tree_sitter.py diff --git a/docs/source/references/llama_stack_client_cli_reference.md b/docs/source/references/llama_stack_client_cli_reference.md index 2d386dbfa..d4d79cea1 100644 --- a/docs/source/references/llama_stack_client_cli_reference.md +++ b/docs/source/references/llama_stack_client_cli_reference.md @@ -478,7 +478,6 @@ llama-stack-client scoring_functions list ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ ┃ identifier ┃ provider_id ┃ description ┃ type ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ -│ basic::bfcl │ basic │ BFCL complex scoring │ scoring_function │ │ basic::docvqa │ basic │ DocVQA Visual Question & Answer scoring function │ scoring_function │ │ basic::equality │ basic │ Returns 1.0 if the input is equal to the target, 0.0 │ scoring_function │ │ │ │ otherwise. 
│ │ diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index f734d0285..bccea48d3 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -105,12 +105,12 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): method = getattr(impls[api], register_method) for obj in objects: - logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}") - - # Do not register models on disabled providers - if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"): - logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.") - continue + if hasattr(obj, "provider_id"): + # Do not register models on disabled providers + if not obj.provider_id or obj.provider_id == "__disabled__": + logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.") + continue + logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}") # we want to maintain the type information in arguments to method. # instead of method(**obj.model_dump()), which may convert a typed attr to a dict, diff --git a/llama_stack/distributions/open-benchmark/open_benchmark.py b/llama_stack/distributions/open-benchmark/open_benchmark.py index af08ac7ba..1d84512cd 100644 --- a/llama_stack/distributions/open-benchmark/open_benchmark.py +++ b/llama_stack/distributions/open-benchmark/open_benchmark.py @@ -43,7 +43,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo "openai", [ ProviderModelEntry( - provider_model_id="openai/gpt-4o", + provider_model_id="gpt-4o", model_type=ModelType.llm, ) ], @@ -53,7 +53,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo "anthropic", [ ProviderModelEntry( - provider_model_id="anthropic/claude-3-5-sonnet-latest", + provider_model_id="claude-3-5-sonnet-latest", model_type=ModelType.llm, ) ], @@ -206,13 +206,6 @@ def get_distribution_template() -> DistributionTemplate: uri="huggingface://datasets/llamastack/math_500?split=test", ), ), - DatasetInput( - dataset_id="bfcl", - purpose=DatasetPurpose.eval_messages_answer, - source=URIDataSource( - uri="huggingface://datasets/llamastack/bfcl_v3?split=train", - ), - ), DatasetInput( dataset_id="ifeval", purpose=DatasetPurpose.eval_messages_answer, @@ -250,11 +243,6 @@ def get_distribution_template() -> DistributionTemplate: dataset_id="math_500", scoring_functions=["basic::regex_parser_math_response"], ), - BenchmarkInput( - benchmark_id="meta-reference-bfcl", - dataset_id="bfcl", - scoring_functions=["basic::bfcl"], - ), BenchmarkInput( benchmark_id="meta-reference-ifeval", dataset_id="ifeval", diff --git a/llama_stack/distributions/open-benchmark/run.yaml b/llama_stack/distributions/open-benchmark/run.yaml index 779bca47e..d068a0b5a 100644 --- a/llama_stack/distributions/open-benchmark/run.yaml +++ b/llama_stack/distributions/open-benchmark/run.yaml @@ -136,14 +136,14 @@ inference_store: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db models: - metadata: {} - model_id: openai/gpt-4o + model_id: gpt-4o provider_id: openai - provider_model_id: openai/gpt-4o + provider_model_id: gpt-4o model_type: llm - metadata: {} - model_id: anthropic/claude-3-5-sonnet-latest + model_id: claude-3-5-sonnet-latest provider_id: anthropic - provider_model_id: anthropic/claude-3-5-sonnet-latest + provider_model_id: claude-3-5-sonnet-latest model_type: llm - metadata: {} model_id: gemini/gemini-1.5-flash 
@@ -188,12 +188,6 @@ datasets: uri: huggingface://datasets/llamastack/math_500?split=test metadata: {} dataset_id: math_500 -- purpose: eval/messages-answer - source: - type: uri - uri: huggingface://datasets/llamastack/bfcl_v3?split=train - metadata: {} - dataset_id: bfcl - purpose: eval/messages-answer source: type: uri @@ -228,11 +222,6 @@ benchmarks: - basic::regex_parser_math_response metadata: {} benchmark_id: meta-reference-math-500 -- dataset_id: bfcl - scoring_functions: - - basic::bfcl - metadata: {} - benchmark_id: meta-reference-bfcl - dataset_id: ifeval scoring_functions: - basic::ifeval diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 91b10daae..b19b68039 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -22,7 +22,6 @@ from llama_stack.providers.utils.common.data_schema_validator import ( ) from .config import BasicScoringConfig -from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn @@ -37,7 +36,6 @@ FIXED_FNS = [ SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn, - BFCLScoringFn, IfEvalScoringFn, DocVQAScoringFn, ] diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py deleted file mode 100644 index b29620be2..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -import re -from typing import Any - -from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn - -from ..utils.bfcl.ast_parser import decode_ast -from ..utils.bfcl.checker import ast_checker, is_empty_output -from .fn_defs.bfcl import bfcl - - -def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]: - contain_func_call = False - error = None - error_type = None - checker_result = {} - try: - prediction = decode_ast(x["generated_answer"], x["language"]) or "" - contain_func_call = True - # if not is_function_calling_format_output(prediction): - if is_empty_output(prediction): - contain_func_call = False - error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability." - error_type = "ast_decoder:decoder_wrong_output_format" - else: - checker_result = ast_checker( - json.loads(x["function"]), - prediction, - json.loads(x["ground_truth"]), - x["language"], - test_category=test_category, - model_name="", - ) - except Exception as e: - prediction = "" - error = f"Invalid syntax. Failed to decode AST. 
{str(e)}" - error_type = "ast_decoder:decoder_failed" - return { - "prediction": prediction, - "contain_func_call": contain_func_call, - "valid": checker_result.get("valid", False), - "error": error or checker_result.get("error", ""), - "error_type": error_type or checker_result.get("error_type", ""), - } - - -def gen_valid(x: dict[str, Any]) -> dict[str, float]: - return {"valid": x["valid"]} - - -def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]: - # This function serves for both relevance and irrelevance tests, which share the exact opposite logic. - # If `test_category` is "irrelevance", the model is expected to output no function call. - # No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`). - # If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call. - acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"] - return {"valid": float(acc)} - - -class BFCLScoringFn(RegisteredBaseScoringFn): - """ - A scoring_fn for BFCL - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.supported_fn_defs_registry = { - bfcl.identifier: bfcl, - } - - async def score_row( - self, - input_row: dict[str, Any], - scoring_fn_identifier: str | None = "bfcl", - scoring_params: ScoringFnParams | None = None, - ) -> ScoringResultRow: - test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"]) - score_result = postprocess(input_row, test_category) - if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}: - score = gen_relevance_acc(score_result)["valid"] - else: - score = gen_valid(score_result)["valid"] - return { - "score": float(score), - } diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py deleted file mode 100644 index 392d92c86..000000000 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/bfcl.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - BasicScoringFnParams, - ScoringFn, -) - -bfcl = ScoringFn( - identifier="basic::bfcl", - description="BFCL complex scoring", - return_type=NumberType(), - provider_id="basic", - provider_resource_id="bfcl", - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]), -) diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/__init__.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/scoring/basic/utils/bfcl/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/ast_parser.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/ast_parser.py deleted file mode 100644 index 445cdfc77..000000000 --- a/llama_stack/providers/inline/scoring/basic/utils/bfcl/ast_parser.py +++ /dev/null @@ -1,296 +0,0 @@ -# ruff: noqa -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -import ast - -from .tree_sitter import get_parser - - -def parse_java_function_call(source_code): - if not source_code.endswith(";"): - source_code += ";" # Necessary for the parser not to register an error - parser = get_parser("java") - tree = parser.parse(bytes(source_code, "utf8")) - root_node = tree.root_node - - if root_node.has_error: - raise Exception("Error parsing java the source code.") - - def get_text(node): - """Returns the text represented by the node.""" - return source_code[node.start_byte : node.end_byte] - - def traverse_node(node, nested=False): - if node.type == "string_literal": - if nested: - return get_text(node) - # Strip surrounding quotes from string literals - return get_text(node)[1:-1] - elif node.type == "character_literal": - if nested: - return get_text(node) - # Strip surrounding single quotes from character literals - return get_text(node)[1:-1] - """Traverse the node to collect texts for complex structures.""" - if node.type in [ - "identifier", - "class_literal", - "type_identifier", - "method_invocation", - ]: - return get_text(node) - elif node.type == "array_creation_expression": - # Handle array creation expression specifically - type_node = node.child_by_field_name("type") - value_node = node.child_by_field_name("value") - type_text = traverse_node(type_node, True) - value_text = traverse_node(value_node, True) - return f"new {type_text}[]{value_text}" - elif node.type == "object_creation_expression": - # Handle object creation expression specifically - type_node = node.child_by_field_name("type") - arguments_node = node.child_by_field_name("arguments") - type_text = traverse_node(type_node, True) - if arguments_node: - # Process each argument carefully, avoiding unnecessary punctuation - argument_texts = [] - for child in arguments_node.children: - if child.type not in [ - ",", - "(", - ")", - ]: # Exclude commas and parentheses - argument_text = traverse_node(child, True) - argument_texts.append(argument_text) - arguments_text = ", ".join(argument_texts) - return f"new {type_text}({arguments_text})" - else: - return f"new {type_text}()" - elif node.type == "set": - # Handling sets specifically - items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]] - return "{" + ", ".join(items) + "}" - - elif node.child_count > 0: - return "".join(traverse_node(child, True) for child in node.children) - else: - return get_text(node) - - def extract_arguments(args_node): - arguments = {} - for child in args_node.children: - if child.type == "assignment_expression": - # For named parameters - name_node, value_node = child.children[0], child.children[2] - name = get_text(name_node) - value = traverse_node(value_node) - if name in arguments: - if not isinstance(arguments[name], list): - arguments[name] = [arguments[name]] - arguments[name].append(value) - else: - arguments[name] = value - # arguments.append({'name': name, 'value': value}) - elif child.type in ["identifier", "class_literal", "set"]: - # For unnamed 
parameters and handling sets - value = traverse_node(child) - if None in arguments: - if not isinstance(arguments[None], list): - arguments[None] = [arguments[None]] - arguments[None].append(value) - else: - arguments[None] = value - return arguments - - def traverse(node): - if node.type == "method_invocation": - # Extract the function name and its arguments - method_name = get_text(node.child_by_field_name("name")) - class_name_node = node.child_by_field_name("object") - if class_name_node: - class_name = get_text(class_name_node) - function_name = f"{class_name}.{method_name}" - else: - function_name = method_name - arguments_node = node.child_by_field_name("arguments") - if arguments_node: - arguments = extract_arguments(arguments_node) - for key, value in arguments.items(): - if isinstance(value, list): - raise Exception("Error: Multiple arguments with the same name are not supported.") - return [{function_name: arguments}] - - else: - for child in node.children: - result = traverse(child) - if result: - return result - - result = traverse(root_node) - return result if result else {} - - -def parse_javascript_function_call(source_code): - if not source_code.endswith(";"): - source_code += ";" # Necessary for the parser not to register an error - parser = get_parser("javascript") - # Parse the source code - tree = parser.parse(bytes(source_code, "utf8")) - root_node = tree.root_node - if root_node.has_error: - raise Exception("Error js parsing the source code.") - - # Function to recursively extract argument details - def extract_arguments(node): - args = {} - for child in node.children: - if child.type == "assignment_expression": - # Extract left (name) and right (value) parts of the assignment - name = child.children[0].text.decode("utf-8") - value = child.children[2].text.decode("utf-8") - if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")): - value = value[1:-1] # Trim the quotation marks - if name in args: - if not isinstance(args[name], list): - args[name] = [args[name]] - args[name].append(value) - else: - args[name] = value - - elif child.type == "identifier" or child.type == "true": - # Handle non-named arguments and boolean values - value = child.text.decode("utf-8") - if None in args: - if not isinstance(args[None], list): - args[None] = [args[None]] - args[None].append(value) - else: - args[None] = value - return args - - # Find the function call and extract its name and arguments - if root_node.type == "program": - for child in root_node.children: - if child.type == "expression_statement": - for sub_child in child.children: - if sub_child.type == "call_expression": - function_name = sub_child.children[0].text.decode("utf8") - arguments_node = sub_child.children[1] - parameters = extract_arguments(arguments_node) - for key, value in parameters.items(): - if isinstance(value, list): - raise Exception("Error: Multiple arguments with the same name are not supported.") - result = [{function_name: parameters}] - return result - - -def ast_parse(input_str, language="Python"): - if language == "Python": - cleaned_input = input_str.strip("[]'") - parsed = ast.parse(cleaned_input, mode="eval") - extracted = [] - if isinstance(parsed.body, ast.Call): - extracted.append(resolve_ast_call(parsed.body)) - else: - for elem in parsed.body.elts: - extracted.append(resolve_ast_call(elem)) - return extracted - elif language == "Java": - return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string - elif language == 
"JavaScript": - return parse_javascript_function_call(input_str[1:-1]) - else: - raise NotImplementedError(f"Unsupported language: {language}") - - -def resolve_ast_call(elem): - # Handle nested attributes for deeply nested module paths - func_parts = [] - func_part = elem.func - while isinstance(func_part, ast.Attribute): - func_parts.append(func_part.attr) - func_part = func_part.value - if isinstance(func_part, ast.Name): - func_parts.append(func_part.id) - func_name = ".".join(reversed(func_parts)) - args_dict = {} - # Parse when args are simply passed as an unnamed dictionary arg - for arg in elem.args: - if isinstance(arg, ast.Dict): - for key, value in zip(arg.keys, arg.values): - if isinstance(key, ast.Constant): - arg_name = key.value - output = resolve_ast_by_type(value) - args_dict[arg_name] = output - for arg in elem.keywords: - output = resolve_ast_by_type(arg.value) - args_dict[arg.arg] = output - return {func_name: args_dict} - - -def resolve_ast_by_type(value): - if isinstance(value, ast.Constant): - if value.value is Ellipsis: - output = "..." - else: - output = value.value - elif isinstance(value, ast.UnaryOp): - output = -value.operand.value - elif isinstance(value, ast.List): - output = [resolve_ast_by_type(v) for v in value.elts] - elif isinstance(value, ast.Dict): - output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)} - elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values - output = value.value - elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments - output = eval(ast.unparse(value)) - elif isinstance(value, ast.Name): - output = value.id - elif isinstance(value, ast.Call): - if len(value.keywords) == 0: - output = ast.unparse(value) - else: - output = resolve_ast_call(value) - elif isinstance(value, ast.Tuple): - output = tuple(resolve_ast_by_type(v) for v in value.elts) - elif isinstance(value, ast.Lambda): - output = eval(ast.unparse(value.body[0].value)) - elif isinstance(value, ast.Ellipsis): - output = "..." - elif isinstance(value, ast.Subscript): - try: - output = ast.unparse(value.body[0].value) - except: - output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]" - else: - raise Exception(f"Unsupported AST type: {type(value)}") - return output - - -def decode_ast(result, language="Python"): - func = result - func = func.replace("\n", "") # remove new line characters - if not func.startswith("["): - func = "[" + func - if not func.endswith("]"): - func = func + "]" - decoded_output = ast_parse(func, language) - return decoded_output - - -def decode_execute(result): - func = result - func = func.replace("\n", "") # remove new line characters - if not func.startswith("["): - func = "[" + func - if not func.endswith("]"): - func = func + "]" - decode_output = ast_parse(func) - execution_list = [] - for function_call in decode_output: - for key, value in function_call.items(): - execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})") - return execution_list diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py deleted file mode 100644 index f6aab123c..000000000 --- a/llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py +++ /dev/null @@ -1,989 +0,0 @@ -# ruff: noqa -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -import json -import re -import time -from typing import Any - -# Comment out for now until we actually use the rest checker in evals -# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function. - - -class NoAPIKeyError(Exception): - def __init__(self): - self.message = "❗️Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate." - super().__init__(self.message) - - -REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2 - - -JAVA_TYPE_CONVERSION = { - "byte": int, - "short": int, - "integer": int, - "float": float, - "double": float, - "long": int, - "boolean": bool, - "char": str, - "Array": list, - "ArrayList": list, - "Set": set, - "HashMap": dict, - "Hashtable": dict, - "Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list - "Stack": list, - "String": str, - "any": str, -} - -JS_TYPE_CONVERSION = { - "String": str, - "integer": int, - "float": float, - "Bigint": int, - "Boolean": bool, - "dict": dict, - "array": list, - "any": str, -} - -# We switch to conditional import for the following two imports to avoid unnecessary installations. -# User doesn't need to setup the tree-sitter packages if they are not running the test for that language. -# from js_type_converter import js_type_converter -# from java_type_converter import java_type_converter - -PYTHON_TYPE_MAPPING = { - "string": str, - "integer": int, - "float": float, - "boolean": bool, - "array": list, - "tuple": list, - "dict": dict, - "any": str, -} - -# This is the list of types that we need to recursively check its values -PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"] - - -NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"] - - -#### Helper functions for AST #### -def find_description(func_descriptions, name): - if type(func_descriptions) == list: - for func_description in func_descriptions: - if func_description["name"] == name: - return func_description - return None - else: - # it is a dict, there is only one function - return func_descriptions - - -def get_possible_answer_type(possible_answer: list): - for answer in possible_answer: - if answer != "": # Optional parameter - return type(answer) - return None - - -def type_checker( - param: str, - value, - possible_answer: list, - expected_type_description: str, - expected_type_converted, - nested_type_converted, -): - # NOTE: This type checker only supports nested type checking for one level deep. - # We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex. - - result: Any = { - "valid": True, - "error": [], - "is_variable": False, - "error_type": "type_error:simple", - } - - is_variable = False - # check for the case where a variable is used instead of a actual value. - # use the type in possible_answer as the expected type - possible_answer_type = get_possible_answer_type(possible_answer) - # if possible_answer only contains optional parameters, we can't determine the type - if possible_answer_type != None: - # we are being precise here. 
- # in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer - if possible_answer_type != expected_type_converted: - is_variable = True - - # value is the same type as in function description - if type(value) == expected_type_converted: - # We don't need to do recursive check for simple types - if nested_type_converted == None: - result["is_variable"] = is_variable - return result - else: - for possible_answer_item in possible_answer: - flag = True # Each parameter should match to at least one possible answer type. - # Here, we assume that each item should be the same type. We could also relax it. - if type(possible_answer_item) == list: - for value_item in value: - checker_result = type_checker( - param, - value_item, - possible_answer_item, - str(nested_type_converted), - nested_type_converted, - None, - ) - if not checker_result["valid"]: - flag = False - break - - if flag: - return {"valid": True, "error": [], "is_variable": is_variable} - - result["valid"] = False - result["error"] = [ - f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}." - ] - result["error_type"] = "type_error:nested" - - # value is not as expected, check for the case where a variable is used instead of a actual value - # use the type in possible_answer as the expected type - possible_answer_type = get_possible_answer_type(possible_answer) - # if possible_answer only contains optional parameters, we can't determine the type - if possible_answer_type != None: - # we are being precise here. - # in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer - if type(value) == possible_answer_type: - result["is_variable"] = True - return result - - result["valid"] = False - result["error"].append( - f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}." - ) - result["error_type"] = "type_error:simple" - return result - - -def standardize_string(input_string: str): - # This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase - # It will also convert all the single quotes to double quotes - # This is used to compare the model output with the possible answers - # We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024 - regex_string = r"[ \,\.\/\-\_\*\^]" - return re.sub(regex_string, "", input_string).lower().replace("'", '"') - - -def string_checker(param: str, model_output: str, possible_answer: list): - standardize_possible_answer = [] - standardize_model_output = standardize_string(model_output) - for i in range(len(possible_answer)): - if type(possible_answer[i]) == str: - standardize_possible_answer.append(standardize_string(possible_answer[i])) - - if standardize_model_output not in standardize_possible_answer: - return { - "valid": False, - "error": [ - f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive." 
- ], - "error_type": "value_error:string", - } - - return {"valid": True, "error": []} - - -def list_checker(param: str, model_output: list, possible_answer: list): - # Convert the tuple to a list - - standardize_model_output = list(model_output) - - # If the element in the list is a string, we need to standardize it - for i in range(len(standardize_model_output)): - if type(standardize_model_output[i]) == str: - standardize_model_output[i] = standardize_string(model_output[i]) - - standardize_possible_answer: Any = [] - # We also need to standardize the possible answers - for i in range(len(possible_answer)): - standardize_possible_answer.append([]) - for j in range(len(possible_answer[i])): - if type(possible_answer[i][j]) == str: - standardize_possible_answer[i].append(standardize_string(possible_answer[i][j])) - else: - standardize_possible_answer[i].append(possible_answer[i][j]) - - if standardize_model_output not in standardize_possible_answer: - return { - "valid": False, - "error": [ - f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}." - ], - "error_type": "value_error:list/tuple", - } - - return {"valid": True, "error": []} - - -def dict_checker(param: str, model_output: dict, possible_answers: list): - # This function works for simple dictionaries, but not dictionaries with nested dictionaries. - # The current dataset only contains simple dictionaries, so this is sufficient. - - result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"} - for i in range(len(possible_answers)): - if possible_answers[i] == "": - continue - - result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"} - - flag = True - - possible_answer = possible_answers[i] - # possible_anwer is a single dictionary - - for key, value in model_output.items(): - if key not in possible_answer: - result["valid"] = False - result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined] - result["error_type"] = "value_error:dict_key" - flag = False - break - - standardize_value = value - # If the value is a string, we need to standardize it - if type(value) == str: - standardize_value = standardize_string(value) - - # We also need to standardize the possible answers if they are string - standardize_possible_answer = [] - for i in range(len(possible_answer[key])): - if type(possible_answer[key][i]) == str: - standardize_possible_answer.append(standardize_string(possible_answer[key][i])) - else: - standardize_possible_answer.append(possible_answer[key][i]) - - if standardize_value not in standardize_possible_answer: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}." 
- ) - result["error_type"] = "value_error:dict_value" - flag = False - break - - for key, value in possible_answer.items(): - if key not in model_output and "" not in value: - result["valid"] = False - result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined] - result["error_type"] = "value_error:dict_key" - flag = False - break - - if flag: - return {"valid": True, "error": []} - - return result - - -def list_dict_checker(param: str, model_output: list, possible_answers: list): - # This function takes in a list of dictionaries and checks if each dictionary is valid - # The order of the dictionaries in the list must match the order of the possible answers - - result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"} - - for answer_index in range(len(possible_answers)): - flag = True # True means so far, all dictionaries are valid - - # Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers - if len(model_output) != len(possible_answers[answer_index]): - result["valid"] = False - result["error"] = ["Wrong number of dictionaries in the list."] - result["error_type"] = "value_error:list_dict_count" - flag = False - continue - - for dict_index in range(len(model_output)): - result = dict_checker( - param, - model_output[dict_index], - [possible_answers[answer_index][dict_index]], - ) - if not result["valid"]: - flag = False - break - if flag: - return {"valid": True, "error": []} - - return result - - -def simple_function_checker( - func_description: dict, - model_output: dict, - possible_answer: dict, - language: str, - model_name: str, -): - possible_answer = list(possible_answer.values())[0] - # Extract function name and parameters details - func_name = func_description["name"] - param_details = func_description["parameters"]["properties"] - required_params = func_description["parameters"]["required"] - - # Initialize a result dictionary - result = { - "valid": True, - "error": [], - "error_type": "simple_function_checker:unclear", - } - - # Check if function name matches - if func_name not in model_output: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Function name {repr(func_name)} not found in model output." 
- ) - result["error_type"] = "simple_function_checker:wrong_func_name" - return result - - model_params = model_output[func_name] - - # Check for required parameters in model output - for param in required_params: - if param not in model_params: - result["valid"] = False - result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined] - result["error_type"] = "simple_function_checker:missing_required" - return result - - # Validate types and values for each parameter in model output - for param, value in model_params.items(): - if param not in param_details or param not in possible_answer: - result["valid"] = False - result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined] - result["error_type"] = "simple_function_checker:unexpected_param" - return result - - full_param_details = param_details[param] - expected_type_description = full_param_details["type"] # This is a string - is_variable = False - nested_type_converted = None - - if language == "Java": - from evals.utils.bfcl.java_type_converter import java_type_converter - - expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description] - - if expected_type_description in JAVA_TYPE_CONVERSION: - if type(value) != str: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}." - ) - result["error_type"] = "type_error:java" - return result - - if expected_type_description in NESTED_CONVERSION_TYPE_LIST: - nested_type = param_details[param]["items"]["type"] - nested_type_converted = JAVA_TYPE_CONVERSION[nested_type] - value = java_type_converter(value, expected_type_description, nested_type) - else: - value = java_type_converter(value, expected_type_description) - - elif language == "JavaScript": - from evals.utils.bfcl.js_type_converter import js_type_converter - - expected_type_converted = JS_TYPE_CONVERSION[expected_type_description] - - if expected_type_description in JS_TYPE_CONVERSION: - if type(value) != str: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}." - ) - result["error_type"] = "type_error:js" - return result - - if expected_type_description in NESTED_CONVERSION_TYPE_LIST: - nested_type = param_details[param]["items"]["type"] - nested_type_converted = JS_TYPE_CONVERSION[nested_type] - value = js_type_converter(value, expected_type_description, nested_type) - else: - value = js_type_converter(value, expected_type_description) - - elif language == "Python": - expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description] - if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST: - nested_type = param_details[param]["items"]["type"] - nested_type_converted = PYTHON_TYPE_MAPPING[nested_type] - - # We convert all tuple value to list when the expected type is tuple. - # The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load(). - # This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future. 
- if expected_type_description == "tuple" and type(value) == tuple: - value = list(value) - - # Allow python auto conversion from int to float - if language == "Python" and expected_type_description == "float" and type(value) == int: - value = float(value) - - # Type checking - # In fact, we only check for Python here. - # Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct. - type_check_result = type_checker( - param, - value, - possible_answer[param], - expected_type_description, - expected_type_converted, - nested_type_converted, - ) - is_variable = type_check_result["is_variable"] - if not type_check_result["valid"]: - return type_check_result - - # It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable. - # We can just treat the variable as a string and use the normal flow. - if not is_variable: - # Special handle for dictionaries - if expected_type_converted == dict: - result = dict_checker(param, value, possible_answer[param]) - if not result["valid"]: - return result - continue - - # Special handle for list of dictionaries - elif expected_type_converted == list and nested_type_converted == dict: - result = list_dict_checker(param, value, possible_answer[param]) - if not result["valid"]: - return result - continue - - # Special handle for strings - elif expected_type_converted == str: - # We don't check for case sensitivity for string, as long as it's not a variable - result = string_checker(param, value, possible_answer[param]) - if not result["valid"]: - return result - continue - - elif expected_type_converted == list: - result = list_checker(param, value, possible_answer[param]) - if not result["valid"]: - return result - continue - - # Check if the value is within the possible answers - if value not in possible_answer[param]: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}." - ) - result["error_type"] = "value_error:others" - return result - - # Check for optional parameters not provided but allowed - for param in possible_answer: - if param not in model_params and "" not in possible_answer[param]: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Optional parameter {repr(param)} not provided and not marked as optional." 
- ) - result["error_type"] = "simple_function_checker:missing_optional" - return result - - return result - - -def parallel_function_checker_enforce_order( - func_descriptions: list, - model_output: list, - possible_answers: dict, - language: str, - model_name: str, -): - if len(model_output) != len(possible_answers): - return { - "valid": False, - "error": ["Wrong number of functions."], - "error_type": "parallel_function_checker_enforce_order:wrong_count", - } - - func_name_list = list(possible_answers.keys()) - possible_answers_list = [] - - for key, value in possible_answers.items(): - possible_answers_list.append({key: value}) - - for i in range(len(possible_answers_list)): - func_description = find_description(func_descriptions, func_name_list[i]) - - result = simple_function_checker( - func_description, - model_output[i], - possible_answers_list[i], - language, - model_name, - ) - if not result["valid"]: - return result - - return {"valid": True, "error": []} - - -def parallel_function_checker_no_order( - func_descriptions: list, - model_output: list, - possible_answers: list, - language: str, - model_name: str, -): - if len(model_output) != len(possible_answers): - return { - "valid": False, - "error": ["Wrong number of functions."], - "error_type": "parallel_function_checker_no_order:wrong_count", - } - - matched_indices = [] - - # We go throught the possible answers one by one, and eliminate the model output that matches the possible answer - # It must be this way because we need ground truth to fetch the correct function description - for i in range(len(possible_answers)): - # possible_answers[i] is a dictionary with only one key - func_name_expected = list(possible_answers[i].keys())[0] - func_description = find_description(func_descriptions, func_name_expected) - - all_errors = [] - - for index in range(len(model_output)): - if index in matched_indices: - continue - - result = simple_function_checker( - func_description, - model_output[index], - possible_answers[i], - language, - model_name, - ) - - if result["valid"]: - matched_indices.append(index) - break - else: - all_errors.append( - { - f"Model Result Index {index}": { - "sub_error": result["error"], - "sub_error_type": result["error_type"], - "model_output_item": model_output[index], - "possible_answer_item": possible_answers[i], - } - } - ) - - if not result["valid"]: - considered_indices = [i for i in range(len(model_output)) if i not in matched_indices] - all_errors.insert( - 0, - f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type] - ) - return { - "valid": False, - "error": all_errors, - "error_type": "parallel_function_checker_no_order:cannot_find_match", - } - - return {"valid": True, "error": []} - - -def multiple_function_checker( - func_descriptions: list, - model_output: list, - possible_answers: list, - language: str, - model_name: str, -): - if len(model_output) != len(possible_answers): - return { - "valid": False, - "error": ["Wrong number of functions."], - "error_type": "multiple_function_checker:wrong_count", - } - - # possible_answers is a list of only one dictionary with only one key - func_name_expected = list(possible_answers[0].keys())[0] - func_description = find_description(func_descriptions, func_name_expected) - return simple_function_checker( - func_description, - model_output[0], - possible_answers[0], - language, - model_name, - ) - - -def patten_matcher(exec_output, expected_result, 
function_call, is_sanity_check): - result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"} - - if type(exec_output) != type(expected_result): - return { - "valid": False, - "error": [ - f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}." - ], - "error_type": "executable_checker:wrong_result_type", - "model_executed_output": exec_output, - } - if type(exec_output) == dict: - # We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one. - # This happens when the key is a timestamp or a random number. - if is_sanity_check: - if len(exec_output) != len(expected_result): - return { - "valid": False, - "error": [ - f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}." - ], - "error_type": "executable_checker:wrong_result_type:dict_length", - "model_executed_output": exec_output, - } - else: - return result - - for key, value in expected_result.items(): - if key not in exec_output: - return { - "valid": False, - "error": [ - f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output." - ], - "error_type": "executable_checker:wrong_result_type:dict_key_not_found", - "model_executed_output": exec_output, - } - for key, value in exec_output.items(): - if key not in expected_result: - return { - "valid": False, - "error": [ - f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output." - ], - "error_type": "executable_checker:wrong_result_type:dict_extra_key", - "model_executed_output": exec_output, - } - if type(exec_output) == list: - if len(exec_output) != len(expected_result): - return { - "valid": False, - "error": [ - f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}." - ], - "error_type": "executable_checker:wrong_result_type:list_length", - "model_executed_output": exec_output, - } - return result - - -#### Helper functions for Exec #### -def executable_checker_simple( - function_call: str, - expected_result, - expected_result_type: str, - is_sanity_check=False, -): - result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"} - - exec_dict: Any = {} - - try: - exec( - "from executable_python_function import *" + "\nresult=" + function_call, - exec_dict, - ) - exec_output = exec_dict["result"] - except NoAPIKeyError as e: - raise e - except Exception as e: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Error in execution: {repr(function_call)}. Error: {str(e)}" - ) - result["error_type"] = "executable_checker:execution_error" - return result - - # We need to special handle the case where the execution result is a tuple and convert it to a list - # Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json - if isinstance(exec_output, tuple): - exec_output = list(exec_output) - - if expected_result_type == "exact_match": - if exec_output != expected_result: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Wrong execution result for {repr(function_call)}. 
Expected: {expected_result}, but got: {exec_output}." - ) - result["error_type"] = "executable_checker:wrong_result" - result["model_executed_output"] = exec_output - return result - - elif expected_result_type == "real_time_match": - # Allow for 5% difference - if (type(expected_result) == float or type(expected_result) == int) and ( - type(exec_output) == float or type(exec_output) == int - ): - if not ( - expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE) - <= exec_output - <= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE) - ): - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed." - ) - result["error_type"] = "executable_checker:wrong_result_real_time" - result["model_executed_output"] = exec_output - return result - else: - result["valid"] = False - result["error"].append( # type: ignore[attr-defined] - f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria." - ) - result["error_type"] = "executable_checker:wrong_result_real_time" - result["model_executed_output"] = exec_output - return result - - else: - # structural match - pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check) - if not pattern_match_result["valid"]: - return pattern_match_result - - return result - - -def executable_checker_parallel_no_order( - decoded_result: list, expected_exec_result: list, expected_exec_result_type: list -): - if len(decoded_result) != len(expected_exec_result): - return { - "valid": False, - "error": [ - f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}." - ], - "error_type": "value_error:exec_result_count", - } - - matched_indices = [] - for i in range(len(expected_exec_result)): - all_errors = [] - for index in range(len(decoded_result)): - if index in matched_indices: - continue - - result = executable_checker_simple( - decoded_result[index], - expected_exec_result[i], - expected_exec_result_type[i], - False, - ) - - if result["valid"]: - matched_indices.append(index) - break - else: - all_errors.append( - { - f"Model Result Index {index}": { - "sub_error": result["error"], - "sub_error_type": result["error_type"], - "model_executed_output": ( - result["model_executed_output"] if "model_executed_output" in result else None - ), - } - } - ) - - if not result["valid"]: - considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices] - all_errors.insert( - 0, - f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type] - ) - return { - "valid": False, - "error": all_errors, - "error_type": "executable_checker:cannot_find_match", - } - - return {"valid": True, "error": [], "error_type": "executable_checker:unclear"} - - -#### Main function #### -def executable_checker_rest(func_call, idx): - # Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used. 
- EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution - with open(EVAL_GROUND_TRUTH_PATH, "r") as f: - EVAL_GROUND_TRUTH = f.readlines() - if "https://geocode.maps.co" in func_call: - time.sleep(2) - if "requests_get" in func_call: - func_call = func_call.replace("requests_get", "requests.get") - try: - response = eval(func_call) - except Exception as e: - return { - "valid": False, - "error": [f"Execution failed. {str(e)}"], - "error_type": "executable_checker_rest:execution_error", - } - - try: - if response.status_code == 200: - eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx]) - try: - if isinstance(eval_GT_json, dict): - if isinstance(response.json(), dict): - if set(eval_GT_json.keys()) == set(response.json().keys()): - return {"valid": True, "error": [], "error_type": ""} - return { - "valid": False, - "error": ["Key inconsistency"], - "error_type": "executable_checker_rest:wrong_key", - } - return { - "valid": False, - "error": [f"Expected dictionary, but got {type(response.json())}"], - "error_type": "executable_checker_rest:wrong_type", - } - - elif isinstance(eval_GT_json, list): - if isinstance(response.json(), list): - if len(eval_GT_json) != len(response.json()): - return { - "valid": False, - "error": [f"Response list length inconsistency."], - "error_type": "value_error:exec_result_rest_count", - } - - else: - for i in range(len(eval_GT_json)): - if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()): - return { - "valid": False, - "error": [f"Key inconsistency"], - "error_type": "executable_checker_rest:wrong_key", - } - - return {"valid": True, "error": []} - else: - return { - "valid": False, - "error": [f"Expected list, but got {type(response.json())}"], - "error_type": "executable_checker_rest:wrong_type", - } - return { - "valid": False, - "error": [f"Expected dict or list, but got {type(response.json())}"], - "error_type": "executable_checker_rest:wrong_type", - } - except Exception as e: - return { - "valid": False, - "error": [ - f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}" - ], - "error_type": "executable_checker_rest:response_format_error", - } - else: - return { - "valid": False, - "error": [f"Execution result status code is not 200, got {response.status_code}"], - "error_type": "executable_checker_rest:wrong_status_code", - } - except Exception as e: - return { - "valid": False, - "error": [f"Cannot get status code of the response. 
Error: {str(e)}"], - "error_type": "executable_checker_rest:cannot_get_status_code", - } - - -def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name): - if "parallel" in test_category: - return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name) - - elif "multiple" in test_category: - return multiple_function_checker(func_description, model_output, possible_answer, language, model_name) - - else: - if len(model_output) != 1: - return { - "valid": False, - "error": ["Wrong number of functions."], - "error_type": "simple_function_checker:wrong_count", - } - - return simple_function_checker( - func_description[0], - model_output[0], - possible_answer[0], - language, - model_name, - ) - - -def exec_checker(decoded_result: list, func_description: dict, test_category: str): - if "multiple" in test_category or "parallel" in test_category: - return executable_checker_parallel_no_order( - decoded_result, - func_description["execution_result"], - func_description["execution_result_type"], - ) - - else: - if len(decoded_result) != 1: - return { - "valid": False, - "error": ["Wrong number of functions."], - "error_type": "simple_exec_checker:wrong_count", - } - return executable_checker_simple( - decoded_result[0], - func_description["execution_result"][0], - func_description["execution_result_type"][0], - False, - ) - - -def is_empty_output(decoded_output): - # This function is a patch to the ast decoder for relevance detection - # Sometimes the ast decoder will parse successfully, but the input doens't really have a function call - # [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct) - if not is_function_calling_format_output(decoded_output): - return True - if len(decoded_output) == 0: - return True - if len(decoded_output) == 1 and len(decoded_output[0]) == 0: - return True - - -def is_function_calling_format_output(decoded_output): - # Ensure the output is a list of dictionaries - if type(decoded_output) == list: - for item in decoded_output: - if type(item) != dict: - return False - return True - return False diff --git a/llama_stack/providers/inline/scoring/basic/utils/bfcl/tree_sitter.py b/llama_stack/providers/inline/scoring/basic/utils/bfcl/tree_sitter.py deleted file mode 100644 index ed97ee360..000000000 --- a/llama_stack/providers/inline/scoring/basic/utils/bfcl/tree_sitter.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -""" -Tree-sitter changes its API with unfortunate frequency. Modules that need it should -import it from here so that we can centrally manage things as necessary. -""" - -# These currently work with tree-sitter 0.23.0 -# NOTE: Don't import tree-sitter or any of the language modules in the main module -# because not all environments have them. Import lazily inside functions where needed. - -import importlib -import typing - -if typing.TYPE_CHECKING: - import tree_sitter - - -def get_language(language: str) -> "tree_sitter.Language": - import tree_sitter - - language_module_name = f"tree_sitter_{language}" - try: - language_module = importlib.import_module(language_module_name) - except ModuleNotFoundError as exc: - raise ValueError( - f"Language {language} is not found. 
Please install the tree-sitter-{language} package." - ) from exc - return tree_sitter.Language(language_module.language()) - - -def get_parser(language: str, **kwargs) -> "tree_sitter.Parser": - import tree_sitter - - lang = get_language(language) - return tree_sitter.Parser(lang, **kwargs)
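
For reference, the removed `tree_sitter.py` wrapper above centralized tree-sitter's frequently changing API behind `get_language` and `get_parser`. Below is a minimal, illustrative usage sketch of such a helper; it is not shipped with this change and assumes the `tree-sitter` (0.23.x) and `tree-sitter-python` packages are installed.

# Illustrative sketch only; assumes tree-sitter 0.23.x and tree-sitter-python are installed.
import importlib

import tree_sitter


def get_parser(language: str) -> "tree_sitter.Parser":
    # Lazily import the per-language grammar module, e.g. tree_sitter_python.
    language_module = importlib.import_module(f"tree_sitter_{language}")
    lang = tree_sitter.Language(language_module.language())
    return tree_sitter.Parser(lang)


parser = get_parser("python")
tree = parser.parse(b"def add(a, b):\n    return a + b\n")
print(tree.root_node.type)  # prints "module" for a parsed Python source file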