diff --git a/llama_stack/providers/remote/files/s3/__init__.py b/llama_stack/providers/remote/files/s3/__init__.py
index 3f5dfc88a..7027f1db3 100644
--- a/llama_stack/providers/remote/files/s3/__init__.py
+++ b/llama_stack/providers/remote/files/s3/__init__.py
@@ -6,15 +6,14 @@
 
 from typing import Any
 
-from llama_stack.core.datatypes import Api
+from llama_stack.core.datatypes import AccessRule, Api
 
 from .config import S3FilesImplConfig
 
 
-async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
+async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
     from .files import S3FilesImpl
 
-    # TODO: authorization policies and user separation
-    impl = S3FilesImpl(config)
+    impl = S3FilesImpl(config, policy or [])
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py
index 52e0cbbf4..0451f74ea 100644
--- a/llama_stack/providers/remote/files/s3/files.py
+++ b/llama_stack/providers/remote/files/s3/files.py
@@ -21,8 +21,10 @@ from llama_stack.apis.files import (
     OpenAIFileObject,
     OpenAIFilePurpose,
 )
+from llama_stack.core.datatypes import AccessRule
 from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
-from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
+from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
 
 from .config import S3FilesImplConfig
 
@@ -89,16 +91,17 @@ class S3FilesImpl(Files):
 
     # TODO: implement expiration, for now a silly offset
    _SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
 
-    def __init__(self, config: S3FilesImplConfig) -> None:
+    def __init__(self, config: S3FilesImplConfig, policy: list[AccessRule]) -> None:
         self._config = config
+        self.policy = policy
         self._client: boto3.client | None = None
-        self._sql_store: SqlStore | None = None
+        self._sql_store: AuthorizedSqlStore | None = None
 
     async def initialize(self) -> None:
         self._client = _create_s3_client(self._config)
         await _create_bucket_if_not_exists(self._client, self._config)
-        self._sql_store = sqlstore_impl(self._config.metadata_store)
+        self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
         await self._sql_store.create_table(
             "openai_files",
             {
@@ -121,7 +124,7 @@ class S3FilesImpl(Files):
         return self._client
 
     @property
-    def sql_store(self) -> SqlStore:
+    def sql_store(self) -> AuthorizedSqlStore:
         assert self._sql_store is not None, "Provider not initialized"
         return self._sql_store
 
@@ -189,6 +192,7 @@ class S3FilesImpl(Files):
 
         paginated_result = await self.sql_store.fetch_all(
             table="openai_files",
+            policy=self.policy,
             where=where_conditions if where_conditions else None,
             order_by=[("created_at", order.value)],
             cursor=("id", after) if after else None,
@@ -216,7 +220,7 @@ class S3FilesImpl(Files):
         )
 
     async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
-        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
+        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
         if not row:
             raise ResourceNotFoundError(file_id, "File", "files.list()")
 
@@ -230,7 +234,7 @@ class S3FilesImpl(Files):
         )
 
     async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
-        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
+        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
         if not row:
             raise ResourceNotFoundError(file_id, "File", "files.list()")
 
@@ -248,7 +252,7 @@ class S3FilesImpl(Files):
         return OpenAIFileDeleteResponse(id=file_id, deleted=True)
 
     async def openai_retrieve_file_content(self, file_id: str) -> Response:
-        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
+        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
        if not row:
             raise ResourceNotFoundError(file_id, "File", "files.list()")
 
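Note on the change above: the `policy` list now flows from `get_adapter_impl` into `S3FilesImpl` and from there into every metadata query via `AuthorizedSqlStore`. A minimal sketch of constructing the provider with the new signature; the bucket name, region, and db path are placeholders, the call assumes a reachable S3 endpoint (or the moto mock used by the tests), and omitting `policy` falls back to the empty policy via `policy or []`:

    import asyncio

    from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
    from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

    async def main() -> None:
        config = S3FilesImplConfig(
            bucket_name="my-bucket",    # placeholder bucket
            region="us-east-1",         # placeholder region
            auto_create_bucket=True,
            metadata_store=SqliteSqlStoreConfig(db_path="/tmp/s3_files_metadata.db"),
        )
        # no policy passed -> S3FilesImpl receives [] via `policy or []`
        provider = await get_adapter_impl(config, {})
        await provider.shutdown()

    asyncio.run(main())

The new tests further down pin down what the empty policy means in practice: rows remain visible only to callers whose attributes (for example roles) match the owning user's.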
diff --git a/tests/unit/providers/files/conftest.py b/tests/unit/providers/files/conftest.py
new file mode 100644
index 000000000..46282e3dc
--- /dev/null
+++ b/tests/unit/providers/files/conftest.py
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import boto3
+import pytest
+from moto import mock_aws
+
+from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
+from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
+
+
+class MockUploadFile:
+    def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
+        self.content = content
+        self.filename = filename
+        self.content_type = content_type
+
+    async def read(self):
+        return self.content
+
+
+@pytest.fixture
+def sample_text_file():
+    content = b"Hello, this is a test file for the S3 Files API!"
+    return MockUploadFile(content, "sample_text_file-0.txt")
+
+
+@pytest.fixture
+def sample_text_file2():
+    content = b"Hello, this is a second test file for the S3 Files API!"
+    return MockUploadFile(content, "sample_text_file-1.txt")
+
+
+@pytest.fixture
+def s3_config(tmp_path):
+    db_path = tmp_path / "s3_files_metadata.db"
+
+    return S3FilesImplConfig(
+        bucket_name=f"test-bucket-{tmp_path.name}",
+        region="not-a-region",
+        auto_create_bucket=True,
+        metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
+    )
+
+
+@pytest.fixture
+def s3_client():
+    # we use `with mock_aws()` because @mock_aws decorator does not support
+    # being a generator
+    with mock_aws():
+        # must yield or the mock will be reset before it is used
+        yield boto3.client("s3")
+
+
+@pytest.fixture
+async def s3_provider(s3_config, s3_client):  # s3_client provides the moto mock, don't remove it
+    provider = await get_adapter_impl(s3_config, {})
+    yield provider
+    await provider.shutdown()
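The fixtures above compose so that a test only has to request `s3_provider` (which pulls in `s3_config` and the moto-backed `s3_client` transitively) plus a sample file. A hypothetical round-trip test, not part of this diff, sketching the intended usage:

    from llama_stack.apis.files import OpenAIFilePurpose

    async def test_upload_roundtrip(s3_provider, sample_text_file):
        # moto is active here because s3_provider depends on the s3_client fixture
        uploaded = await s3_provider.openai_upload_file(
            file=sample_text_file,
            purpose=OpenAIFilePurpose.ASSISTANTS,
        )
        retrieved = await s3_provider.openai_retrieve_file(uploaded.id)
        assert retrieved.id == uploaded.id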
diff --git a/tests/unit/providers/files/test_s3_files.py b/tests/unit/providers/files/test_s3_files.py
index daa250f10..3bd4836df 100644
--- a/tests/unit/providers/files/test_s3_files.py
+++ b/tests/unit/providers/files/test_s3_files.py
@@ -6,63 +6,11 @@
 
 from unittest.mock import patch
 
-import boto3
 import pytest
 from botocore.exceptions import ClientError
-from moto import mock_aws
 
 from llama_stack.apis.common.errors import ResourceNotFoundError
 from llama_stack.apis.files import OpenAIFilePurpose
-from llama_stack.providers.remote.files.s3 import (
-    S3FilesImplConfig,
-    get_adapter_impl,
-)
-from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
-
-
-class MockUploadFile:
-    def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
-        self.content = content
-        self.filename = filename
-        self.content_type = content_type
-
-    async def read(self):
-        return self.content
-
-
-@pytest.fixture
-def s3_config(tmp_path):
-    db_path = tmp_path / "s3_files_metadata.db"
-
-    return S3FilesImplConfig(
-        bucket_name="test-bucket",
-        region="not-a-region",
-        auto_create_bucket=True,
-        metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
-    )
-
-
-@pytest.fixture
-def s3_client():
-    """Create a mocked S3 client for testing."""
-    # we use `with mock_aws()` because @mock_aws decorator does not support being a generator
-    with mock_aws():
-        # must yield or the mock will be reset before it is used
-        yield boto3.client("s3")
-
-
-@pytest.fixture
-async def s3_provider(s3_config, s3_client):
-    """Create an S3 files provider with mocked S3 for testing."""
-    provider = await get_adapter_impl(s3_config, {})
-    yield provider
-    await provider.shutdown()
-
-
-@pytest.fixture
-def sample_text_file():
-    content = b"Hello, this is a test file for the S3 Files API!"
-    return MockUploadFile(content, "sample_text_file.txt")
 
 
 class TestS3FilesImpl:
@@ -143,7 +91,7 @@ class TestS3FilesImpl:
             s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
         assert exc_info.value.response["Error"]["Code"] == "404"
 
-    async def test_list_files(self, s3_provider, sample_text_file):
+    async def test_list_files(self, s3_provider, sample_text_file, sample_text_file2):
         """Test listing files after uploading some."""
         sample_text_file.filename = "test_list_files_with_content_file1"
         file1 = await s3_provider.openai_upload_file(
@@ -151,9 +99,9 @@ class TestS3FilesImpl:
             purpose=OpenAIFilePurpose.ASSISTANTS,
         )
 
-        file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
+        sample_text_file2.filename = "test_list_files_with_content_file2"
         file2 = await s3_provider.openai_upload_file(
-            file=file2_content,
+            file=sample_text_file2,
             purpose=OpenAIFilePurpose.BATCH,
         )
 
@@ -164,7 +112,7 @@ class TestS3FilesImpl:
         assert file1.id in file_ids
         assert file2.id in file_ids
 
-    async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
+    async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file, sample_text_file2):
         """Test listing files with purpose filter."""
         sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
         file1 = await s3_provider.openai_upload_file(
@@ -172,9 +120,9 @@ class TestS3FilesImpl:
             purpose=OpenAIFilePurpose.ASSISTANTS,
         )
 
-        file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
+        sample_text_file2.filename = "test_list_files_with_purpose_filter_file2"
         await s3_provider.openai_upload_file(
-            file=file2_content,
+            file=sample_text_file2,
             purpose=OpenAIFilePurpose.BATCH,
        )
 
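Every test in the new auth module below uses the same identity-switching idiom: patch `get_authenticated_user` at the point where `AuthorizedSqlStore` resolves the caller, run some calls as one user, then repeat as another. A condensed helper expressing that idiom (hypothetical, not part of the diff):

    from contextlib import contextmanager
    from unittest.mock import patch

    from llama_stack.core.datatypes import User

    @contextmanager
    def as_user(user: User):
        # AuthorizedSqlStore looks up the caller through this function, so
        # patching it both attributes new rows and filters reads/deletes
        target = "llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user"
        with patch(target) as mock_get_user:
            mock_get_user.return_value = user
            yield

    # e.g. with as_user(User("user-a", {"roles": ["team-a"]})): ...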
diff --git a/tests/unit/providers/files/test_s3_files_auth.py b/tests/unit/providers/files/test_s3_files_auth.py
new file mode 100644
index 000000000..6097f2808
--- /dev/null
+++ b/tests/unit/providers/files/test_s3_files_auth.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import patch
+
+import pytest
+
+from llama_stack.apis.common.errors import ResourceNotFoundError
+from llama_stack.apis.files import OpenAIFilePurpose
+from llama_stack.core.datatypes import User
+from llama_stack.providers.remote.files.s3.files import S3FilesImpl
+
+
+async def test_listing_hides_other_users_file(s3_provider, sample_text_file):
+    """Listing should not show files uploaded by other users."""
+    user_a = User("user-a", {"roles": ["team-a"]})
+    user_b = User("user-b", {"roles": ["team-b"]})
+
+    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+        mock_get_user.return_value = user_a
+        uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
+
+    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+        mock_get_user.return_value = user_b
+        listed = await s3_provider.openai_list_files()
+        assert all(f.id != uploaded.id for f in listed.data)
+
+
+@pytest.mark.parametrize(
+    "op",
+    [S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file],
+    ids=["retrieve", "content", "delete"],
+)
+async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op):
+    """Operations (metadata/content/delete) on another user's file should raise ResourceNotFoundError.
+
+    `op` is an async callable (provider, file_id) -> awaits the requested operation.
+    """
+    user_a = User("user-a", {"roles": ["team-a"]})
+    user_b = User("user-b", {"roles": ["team-b"]})
+
+    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+        mock_get_user.return_value = user_a
+        uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
+
+    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+        mock_get_user.return_value = user_b
+        with pytest.raises(ResourceNotFoundError):
+            await op(s3_provider, uploaded.id)
+
+
+async def test_shared_role_allows_listing(s3_provider, sample_text_file):
+    """Listing should show files uploaded by other users when roles are shared."""
+    user_a = User("user-a", {"roles": ["shared-role"]})
+    user_b = User("user-b", {"roles": ["shared-role"]})
+
+    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+        mock_get_user.return_value = user_a
+        uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
+
+    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+        mock_get_user.return_value = user_b
+        listed = await s3_provider.openai_list_files()
+        assert any(f.id == uploaded.id for f in listed.data)
+
+
+@pytest.mark.parametrize(
+    "op",
+    [S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file],
+    ids=["retrieve", "content", "delete"],
+)
+async def test_shared_role_allows_access(s3_provider, sample_text_file, op):
+    """Operations (metadata/content/delete) on another user's file should succeed when users share a role.
+
+    `op` is an async callable (provider, file_id) -> awaits the requested operation.
+    """
+    user_x = User("user-x", {"roles": ["shared-role"]})
+    user_y = User("user-y", {"roles": ["shared-role"]})
+
+    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+        mock_get_user.return_value = user_x
+        uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
+
+    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+        mock_get_user.return_value = user_y
+        await op(s3_provider, uploaded.id)