feat(s3 auth): add authorization support for s3 files provider (#3265)

# What does this PR do?

Adds support for authorized users to the S3 files provider.

## Test Plan

existing and new unit tests
This commit is contained in:
Matthew Farrellee 2025-08-29 10:14:00 -04:00 committed by GitHub
parent ed418653ec
commit e96e3c4da4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 172 additions and 70 deletions

View file

@ -6,15 +6,14 @@
from typing import Any from typing import Any
from llama_stack.core.datatypes import Api from llama_stack.core.datatypes import AccessRule, Api
from .config import S3FilesImplConfig from .config import S3FilesImplConfig
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]): async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
from .files import S3FilesImpl from .files import S3FilesImpl
# TODO: authorization policies and user separation impl = S3FilesImpl(config, policy or [])
impl = S3FilesImpl(config)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -21,8 +21,10 @@ from llama_stack.apis.files import (
OpenAIFileObject, OpenAIFileObject,
OpenAIFilePurpose, OpenAIFilePurpose,
) )
from llama_stack.core.datatypes import AccessRule
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from .config import S3FilesImplConfig from .config import S3FilesImplConfig
@ -89,16 +91,17 @@ class S3FilesImpl(Files):
# TODO: implement expiration, for now a silly offset # TODO: implement expiration, for now a silly offset
_SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60 _SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
def __init__(self, config: S3FilesImplConfig) -> None: def __init__(self, config: S3FilesImplConfig, policy: list[AccessRule]) -> None:
self._config = config self._config = config
self.policy = policy
self._client: boto3.client | None = None self._client: boto3.client | None = None
self._sql_store: SqlStore | None = None self._sql_store: AuthorizedSqlStore | None = None
async def initialize(self) -> None: async def initialize(self) -> None:
self._client = _create_s3_client(self._config) self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config) await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = sqlstore_impl(self._config.metadata_store) self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
await self._sql_store.create_table( await self._sql_store.create_table(
"openai_files", "openai_files",
{ {
@ -121,7 +124,7 @@ class S3FilesImpl(Files):
return self._client return self._client
@property @property
def sql_store(self) -> SqlStore: def sql_store(self) -> AuthorizedSqlStore:
assert self._sql_store is not None, "Provider not initialized" assert self._sql_store is not None, "Provider not initialized"
return self._sql_store return self._sql_store
@ -189,6 +192,7 @@ class S3FilesImpl(Files):
paginated_result = await self.sql_store.fetch_all( paginated_result = await self.sql_store.fetch_all(
table="openai_files", table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None, where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
@ -216,7 +220,7 @@ class S3FilesImpl(Files):
) )
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row: if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()") raise ResourceNotFoundError(file_id, "File", "files.list()")
@ -230,7 +234,7 @@ class S3FilesImpl(Files):
) )
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse: async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row: if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()") raise ResourceNotFoundError(file_id, "File", "files.list()")
@ -248,7 +252,7 @@ class S3FilesImpl(Files):
return OpenAIFileDeleteResponse(id=file_id, deleted=True) return OpenAIFileDeleteResponse(id=file_id, deleted=True)
async def openai_retrieve_file_content(self, file_id: str) -> Response: async def openai_retrieve_file_content(self, file_id: str) -> Response:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row: if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()") raise ResourceNotFoundError(file_id, "File", "files.list()")

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import boto3
import pytest
from moto import mock_aws
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
class MockUploadFile:
    """Minimal stand-in for an uploaded file: raw bytes plus name and MIME type."""

    def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
        self.content, self.filename, self.content_type = content, filename, content_type

    async def read(self):
        """Return the full payload, mimicking the async read() of real upload objects."""
        return self.content
@pytest.fixture
def sample_text_file():
    """First sample upload: a small fixed text payload with a deterministic filename."""
    return MockUploadFile(
        b"Hello, this is a test file for the S3 Files API!",
        "sample_text_file-0.txt",
    )
@pytest.fixture
def sample_text_file2():
    """Second sample upload, distinct content and filename from sample_text_file."""
    return MockUploadFile(
        b"Hello, this is a second test file for the S3 Files API!",
        "sample_text_file-1.txt",
    )
@pytest.fixture
def s3_config(tmp_path):
    """Provider config with a per-test bucket name and a sqlite metadata store under tmp_path."""
    metadata_db = tmp_path / "s3_files_metadata.db"
    return S3FilesImplConfig(
        bucket_name=f"test-bucket-{tmp_path.name}",
        region="not-a-region",
        auto_create_bucket=True,
        metadata_store=SqliteSqlStoreConfig(db_path=metadata_db.as_posix()),
    )
@pytest.fixture
def s3_client():
    """Yield a boto3 S3 client while a moto mock is active.

    `with mock_aws()` is used (rather than the @mock_aws decorator) because
    the decorator does not support generator functions; we must yield so the
    mock stays in effect for the duration of the test.
    """
    with mock_aws():
        yield boto3.client("s3")
@pytest.fixture
async def s3_provider(s3_config, s3_client):
    """Initialized S3 files provider; the s3_client argument keeps the moto mock active — don't remove it."""
    impl = await get_adapter_impl(s3_config, {})
    yield impl
    await impl.shutdown()

View file

@ -6,63 +6,11 @@
from unittest.mock import patch from unittest.mock import patch
import boto3
import pytest import pytest
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from moto import mock_aws
from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.providers.remote.files.s3 import (
S3FilesImplConfig,
get_adapter_impl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
class MockUploadFile:
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
self.content = content
self.filename = filename
self.content_type = content_type
async def read(self):
return self.content
@pytest.fixture
def s3_config(tmp_path):
db_path = tmp_path / "s3_files_metadata.db"
return S3FilesImplConfig(
bucket_name="test-bucket",
region="not-a-region",
auto_create_bucket=True,
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
)
@pytest.fixture
def s3_client():
"""Create a mocked S3 client for testing."""
# we use `with mock_aws()` because @mock_aws decorator does not support being a generator
with mock_aws():
# must yield or the mock will be reset before it is used
yield boto3.client("s3")
@pytest.fixture
async def s3_provider(s3_config, s3_client):
"""Create an S3 files provider with mocked S3 for testing."""
provider = await get_adapter_impl(s3_config, {})
yield provider
await provider.shutdown()
@pytest.fixture
def sample_text_file():
content = b"Hello, this is a test file for the S3 Files API!"
return MockUploadFile(content, "sample_text_file.txt")
class TestS3FilesImpl: class TestS3FilesImpl:
@ -143,7 +91,7 @@ class TestS3FilesImpl:
s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id) s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
assert exc_info.value.response["Error"]["Code"] == "404" assert exc_info.value.response["Error"]["Code"] == "404"
async def test_list_files(self, s3_provider, sample_text_file): async def test_list_files(self, s3_provider, sample_text_file, sample_text_file2):
"""Test listing files after uploading some.""" """Test listing files after uploading some."""
sample_text_file.filename = "test_list_files_with_content_file1" sample_text_file.filename = "test_list_files_with_content_file1"
file1 = await s3_provider.openai_upload_file( file1 = await s3_provider.openai_upload_file(
@ -151,9 +99,9 @@ class TestS3FilesImpl:
purpose=OpenAIFilePurpose.ASSISTANTS, purpose=OpenAIFilePurpose.ASSISTANTS,
) )
file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2") sample_text_file2.filename = "test_list_files_with_content_file2"
file2 = await s3_provider.openai_upload_file( file2 = await s3_provider.openai_upload_file(
file=file2_content, file=sample_text_file2,
purpose=OpenAIFilePurpose.BATCH, purpose=OpenAIFilePurpose.BATCH,
) )
@ -164,7 +112,7 @@ class TestS3FilesImpl:
assert file1.id in file_ids assert file1.id in file_ids
assert file2.id in file_ids assert file2.id in file_ids
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file): async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file, sample_text_file2):
"""Test listing files with purpose filter.""" """Test listing files with purpose filter."""
sample_text_file.filename = "test_list_files_with_purpose_filter_file1" sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
file1 = await s3_provider.openai_upload_file( file1 = await s3_provider.openai_upload_file(
@ -172,9 +120,9 @@ class TestS3FilesImpl:
purpose=OpenAIFilePurpose.ASSISTANTS, purpose=OpenAIFilePurpose.ASSISTANTS,
) )
file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2") sample_text_file2.filename = "test_list_files_with_purpose_filter_file2"
await s3_provider.openai_upload_file( await s3_provider.openai_upload_file(
file=file2_content, file=sample_text_file2,
purpose=OpenAIFilePurpose.BATCH, purpose=OpenAIFilePurpose.BATCH,
) )

View file

@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import patch
import pytest
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.core.datatypes import User
from llama_stack.providers.remote.files.s3.files import S3FilesImpl
async def test_listing_hides_other_users_file(s3_provider, sample_text_file):
    """Listing should not show files uploaded by other users."""
    owner = User("user-a", {"roles": ["team-a"]})
    other = User("user-b", {"roles": ["team-b"]})
    auth_target = "llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user"

    # Upload as the owner.
    with patch(auth_target) as mock_get_user:
        mock_get_user.return_value = owner
        uploaded = await s3_provider.openai_upload_file(
            file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
        )

    # List as an unrelated user: the owner's file must be invisible.
    with patch(auth_target) as mock_get_user:
        mock_get_user.return_value = other
        listed = await s3_provider.openai_list_files()
        assert uploaded.id not in {f.id for f in listed.data}
@pytest.mark.parametrize(
    "op",
    [S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file],
    ids=["retrieve", "content", "delete"],
)
async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op):
    """Operations (metadata/content/delete) on another user's file should raise ResourceNotFoundError.

    `op` is an async callable (provider, file_id) -> awaits the requested operation.
    """
    owner = User("user-a", {"roles": ["team-a"]})
    intruder = User("user-b", {"roles": ["team-b"]})
    auth_target = "llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user"

    # Upload as the owner.
    with patch(auth_target) as mock_get_user:
        mock_get_user.return_value = owner
        uploaded = await s3_provider.openai_upload_file(
            file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
        )

    # Any access by a user with disjoint roles must look like a missing file.
    with patch(auth_target) as mock_get_user:
        mock_get_user.return_value = intruder
        with pytest.raises(ResourceNotFoundError):
            await op(s3_provider, uploaded.id)
async def test_shared_role_allows_listing(s3_provider, sample_text_file):
    """Listing should show files uploaded by other users when roles are shared."""
    uploader = User("user-a", {"roles": ["shared-role"]})
    viewer = User("user-b", {"roles": ["shared-role"]})
    auth_target = "llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user"

    # Upload as the first user.
    with patch(auth_target) as mock_get_user:
        mock_get_user.return_value = uploader
        uploaded = await s3_provider.openai_upload_file(
            file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
        )

    # A second user holding the same role must see the file.
    with patch(auth_target) as mock_get_user:
        mock_get_user.return_value = viewer
        listed = await s3_provider.openai_list_files()
        assert uploaded.id in {f.id for f in listed.data}
@pytest.mark.parametrize(
    "op",
    [S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file],
    ids=["retrieve", "content", "delete"],
)
async def test_shared_role_allows_access(s3_provider, sample_text_file, op):
    """Operations (metadata/content/delete) on another user's file should succeed when users share a role.

    `op` is an async callable (provider, file_id) -> awaits the requested operation.
    """
    uploader = User("user-x", {"roles": ["shared-role"]})
    accessor = User("user-y", {"roles": ["shared-role"]})
    auth_target = "llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user"

    # Upload as the first user.
    with patch(auth_target) as mock_get_user:
        mock_get_user.return_value = uploader
        uploaded = await s3_provider.openai_upload_file(
            file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
        )

    # A second user with the shared role performs the operation; no exception expected.
    with patch(auth_target) as mock_get_user:
        mock_get_user.return_value = accessor
        await op(s3_provider, uploaded.id)