mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat(s3 auth): add authorization support for s3 files provider (#3265)
# What does this PR do? adds support for authorized users to the s3 files provider ## Test Plan existing and new unit tests
This commit is contained in:
parent
ed418653ec
commit
e96e3c4da4
5 changed files with 172 additions and 70 deletions
|
@ -6,15 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
|
||||
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
|
||||
from .files import S3FilesImpl
|
||||
|
||||
# TODO: authorization policies and user separation
|
||||
impl = S3FilesImpl(config)
|
||||
impl = S3FilesImpl(config, policy or [])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -21,8 +21,10 @@ from llama_stack.apis.files import (
|
|||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
|
@ -89,16 +91,17 @@ class S3FilesImpl(Files):
|
|||
# TODO: implement expiration, for now a silly offset
|
||||
_SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
|
||||
|
||||
def __init__(self, config: S3FilesImplConfig) -> None:
|
||||
def __init__(self, config: S3FilesImplConfig, policy: list[AccessRule]) -> None:
|
||||
self._config = config
|
||||
self.policy = policy
|
||||
self._client: boto3.client | None = None
|
||||
self._sql_store: SqlStore | None = None
|
||||
self._sql_store: AuthorizedSqlStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self._client = _create_s3_client(self._config)
|
||||
await _create_bucket_if_not_exists(self._client, self._config)
|
||||
|
||||
self._sql_store = sqlstore_impl(self._config.metadata_store)
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
@ -121,7 +124,7 @@ class S3FilesImpl(Files):
|
|||
return self._client
|
||||
|
||||
@property
|
||||
def sql_store(self) -> SqlStore:
|
||||
def sql_store(self) -> AuthorizedSqlStore:
|
||||
assert self._sql_store is not None, "Provider not initialized"
|
||||
return self._sql_store
|
||||
|
||||
|
@ -189,6 +192,7 @@ class S3FilesImpl(Files):
|
|||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
policy=self.policy,
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
|
@ -216,7 +220,7 @@ class S3FilesImpl(Files):
|
|||
)
|
||||
|
||||
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
|
@ -230,7 +234,7 @@ class S3FilesImpl(Files):
|
|||
)
|
||||
|
||||
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
|
@ -248,7 +252,7 @@ class S3FilesImpl(Files):
|
|||
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
|
|
62
tests/unit/providers/files/conftest.py
Normal file
62
tests/unit/providers/files/conftest.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
|
||||
self.content = content
|
||||
self.filename = filename
|
||||
self.content_type = content_type
|
||||
|
||||
async def read(self):
|
||||
return self.content
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_text_file():
|
||||
content = b"Hello, this is a test file for the S3 Files API!"
|
||||
return MockUploadFile(content, "sample_text_file-0.txt")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_text_file2():
|
||||
content = b"Hello, this is a second test file for the S3 Files API!"
|
||||
return MockUploadFile(content, "sample_text_file-1.txt")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_config(tmp_path):
|
||||
db_path = tmp_path / "s3_files_metadata.db"
|
||||
|
||||
return S3FilesImplConfig(
|
||||
bucket_name=f"test-bucket-{tmp_path.name}",
|
||||
region="not-a-region",
|
||||
auto_create_bucket=True,
|
||||
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_client():
|
||||
# we use `with mock_aws()` because @mock_aws decorator does not support
|
||||
# being a generator
|
||||
with mock_aws():
|
||||
# must yield or the mock will be reset before it is used
|
||||
yield boto3.client("s3")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def s3_provider(s3_config, s3_client): # s3_client provides the moto mock, don't remove it
|
||||
provider = await get_adapter_impl(s3_config, {})
|
||||
yield provider
|
||||
await provider.shutdown()
|
|
@ -6,63 +6,11 @@
|
|||
|
||||
from unittest.mock import patch
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
from moto import mock_aws
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.providers.remote.files.s3 import (
|
||||
S3FilesImplConfig,
|
||||
get_adapter_impl,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
|
||||
self.content = content
|
||||
self.filename = filename
|
||||
self.content_type = content_type
|
||||
|
||||
async def read(self):
|
||||
return self.content
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_config(tmp_path):
|
||||
db_path = tmp_path / "s3_files_metadata.db"
|
||||
|
||||
return S3FilesImplConfig(
|
||||
bucket_name="test-bucket",
|
||||
region="not-a-region",
|
||||
auto_create_bucket=True,
|
||||
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_client():
|
||||
"""Create a mocked S3 client for testing."""
|
||||
# we use `with mock_aws()` because @mock_aws decorator does not support being a generator
|
||||
with mock_aws():
|
||||
# must yield or the mock will be reset before it is used
|
||||
yield boto3.client("s3")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def s3_provider(s3_config, s3_client):
|
||||
"""Create an S3 files provider with mocked S3 for testing."""
|
||||
provider = await get_adapter_impl(s3_config, {})
|
||||
yield provider
|
||||
await provider.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_text_file():
|
||||
content = b"Hello, this is a test file for the S3 Files API!"
|
||||
return MockUploadFile(content, "sample_text_file.txt")
|
||||
|
||||
|
||||
class TestS3FilesImpl:
|
||||
|
@ -143,7 +91,7 @@ class TestS3FilesImpl:
|
|||
s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
|
||||
assert exc_info.value.response["Error"]["Code"] == "404"
|
||||
|
||||
async def test_list_files(self, s3_provider, sample_text_file):
|
||||
async def test_list_files(self, s3_provider, sample_text_file, sample_text_file2):
|
||||
"""Test listing files after uploading some."""
|
||||
sample_text_file.filename = "test_list_files_with_content_file1"
|
||||
file1 = await s3_provider.openai_upload_file(
|
||||
|
@ -151,9 +99,9 @@ class TestS3FilesImpl:
|
|||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
|
||||
sample_text_file2.filename = "test_list_files_with_content_file2"
|
||||
file2 = await s3_provider.openai_upload_file(
|
||||
file=file2_content,
|
||||
file=sample_text_file2,
|
||||
purpose=OpenAIFilePurpose.BATCH,
|
||||
)
|
||||
|
||||
|
@ -164,7 +112,7 @@ class TestS3FilesImpl:
|
|||
assert file1.id in file_ids
|
||||
assert file2.id in file_ids
|
||||
|
||||
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
|
||||
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file, sample_text_file2):
|
||||
"""Test listing files with purpose filter."""
|
||||
sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
|
||||
file1 = await s3_provider.openai_upload_file(
|
||||
|
@ -172,9 +120,9 @@ class TestS3FilesImpl:
|
|||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
|
||||
sample_text_file2.filename = "test_list_files_with_purpose_filter_file2"
|
||||
await s3_provider.openai_upload_file(
|
||||
file=file2_content,
|
||||
file=sample_text_file2,
|
||||
purpose=OpenAIFilePurpose.BATCH,
|
||||
)
|
||||
|
||||
|
|
89
tests/unit/providers/files/test_s3_files_auth.py
Normal file
89
tests/unit/providers/files/test_s3_files_auth.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.providers.remote.files.s3.files import S3FilesImpl
|
||||
|
||||
|
||||
async def test_listing_hides_other_users_file(s3_provider, sample_text_file):
|
||||
"""Listing should not show files uploaded by other users."""
|
||||
user_a = User("user-a", {"roles": ["team-a"]})
|
||||
user_b = User("user-b", {"roles": ["team-b"]})
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_a
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_b
|
||||
listed = await s3_provider.openai_list_files()
|
||||
assert all(f.id != uploaded.id for f in listed.data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"op",
|
||||
[S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file],
|
||||
ids=["retrieve", "content", "delete"],
|
||||
)
|
||||
async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op):
|
||||
"""Operations (metadata/content/delete) on another user's file should raise ResourceNotFoundError.
|
||||
|
||||
`op` is an async callable (provider, file_id) -> awaits the requested operation.
|
||||
"""
|
||||
user_a = User("user-a", {"roles": ["team-a"]})
|
||||
user_b = User("user-b", {"roles": ["team-b"]})
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_a
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_b
|
||||
with pytest.raises(ResourceNotFoundError):
|
||||
await op(s3_provider, uploaded.id)
|
||||
|
||||
|
||||
async def test_shared_role_allows_listing(s3_provider, sample_text_file):
|
||||
"""Listing should show files uploaded by other users when roles are shared."""
|
||||
user_a = User("user-a", {"roles": ["shared-role"]})
|
||||
user_b = User("user-b", {"roles": ["shared-role"]})
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_a
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_b
|
||||
listed = await s3_provider.openai_list_files()
|
||||
assert any(f.id == uploaded.id for f in listed.data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"op",
|
||||
[S3FilesImpl.openai_retrieve_file, S3FilesImpl.openai_retrieve_file_content, S3FilesImpl.openai_delete_file],
|
||||
ids=["retrieve", "content", "delete"],
|
||||
)
|
||||
async def test_shared_role_allows_access(s3_provider, sample_text_file, op):
|
||||
"""Operations (metadata/content/delete) on another user's file should succeed when users share a role.
|
||||
|
||||
`op` is an async callable (provider, file_id) -> awaits the requested operation.
|
||||
"""
|
||||
user_x = User("user-x", {"roles": ["shared-role"]})
|
||||
user_y = User("user-y", {"roles": ["shared-role"]})
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_x
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_y
|
||||
await op(s3_provider, uploaded.id)
|
Loading…
Add table
Add a link
Reference in a new issue