llama-stack-mirror/src/llama_stack/providers/remote/files/s3/files.py
Derek Higgins 89f9ae359d fix(files): Enforce DELETE action permission for file deletion
Previously, file deletion only checked READ permission via the
_lookup_file_id() method. This meant any user with READ access to
a file could also delete it, making it impossible to configure
read-only file access.

This change adds an 'action' parameter to fetch_all() and fetch_one()
in AuthorizedSqlStore, defaulting to Action.READ for backward
compatibility. The openai_delete_file() method now passes
Action.DELETE, ensuring proper RBAC enforcement.

With this fix, access policies can now distinguish between
Users who can read/list files but not delete them

Closes: #4274

Signed-off-by: Derek Higgins <derekh@redhat.com>
2025-12-02 16:05:31 +00:00

322 lines
12 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Annotated, Any, cast
import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import Depends, File, Form, Response, UploadFile
if TYPE_CHECKING:
from mypy_boto3_s3.client import S3Client
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack_api import (
ExpiresAfter,
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
Order,
ResourceNotFoundError,
)
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
from .config import S3FilesImplConfig
# TODO: provider data for S3 credentials
def _create_s3_client(config: S3FilesImplConfig) -> "S3Client":
try:
s3_config = {
"region_name": config.region,
}
# endpoint URL if specified (for MinIO, LocalStack, etc.)
if config.endpoint_url:
s3_config["endpoint_url"] = config.endpoint_url
if config.aws_access_key_id and config.aws_secret_access_key:
s3_config.update(
{
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
}
)
# Both cast and type:ignore are needed here:
# - cast tells mypy the return type for downstream usage (S3Client vs generic client)
# - type:ignore suppresses the call-overload error from boto3's complex overloaded signatures
return cast("S3Client", boto3.client("s3", **s3_config)) # type: ignore[call-overload]
except (BotoCoreError, NoCredentialsError) as e:
raise RuntimeError(f"Failed to initialize S3 client: {e}") from e
async def _create_bucket_if_not_exists(client: "S3Client", config: S3FilesImplConfig) -> None:
try:
client.head_bucket(Bucket=config.bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
if error_code == "404":
if not config.auto_create_bucket:
raise RuntimeError(
f"S3 bucket '{config.bucket_name}' does not exist. "
f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
) from e
try:
# For us-east-1, we can't specify LocationConstraint
if config.region == "us-east-1":
client.create_bucket(Bucket=config.bucket_name)
else:
client.create_bucket(
Bucket=config.bucket_name,
CreateBucketConfiguration=cast(Any, {"LocationConstraint": config.region}),
)
except ClientError as create_error:
raise RuntimeError(
f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
) from create_error
elif error_code == "403":
raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
else:
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
def _make_file_object(
*,
id: str,
filename: str,
purpose: str,
bytes: int,
created_at: int,
expires_at: int,
**kwargs: Any, # here to ignore any additional fields, e.g. extra fields from AuthorizedSqlStore
) -> OpenAIFileObject:
"""
Construct an OpenAIFileObject and normalize expires_at.
If expires_at is greater than the max we treat it as no-expiration and
return None for expires_at.
The OpenAI spec says expires_at type is Integer, but the implementation
will return None for no expiration.
"""
obj = OpenAIFileObject(
id=id,
filename=filename,
purpose=OpenAIFilePurpose(purpose),
bytes=bytes,
created_at=created_at,
expires_at=expires_at,
)
if obj.expires_at is not None and obj.expires_at > (obj.created_at + ExpiresAfter.MAX):
obj.expires_at = None # type: ignore
return obj
class S3FilesImpl(Files):
"""S3-based implementation of the Files API."""
def __init__(self, config: S3FilesImplConfig, policy: list[AccessRule]) -> None:
self._config = config
self.policy = policy
self._client: S3Client | None = None
self._sql_store: AuthorizedSqlStore | None = None
def _now(self) -> int:
"""Return current UTC timestamp as int seconds."""
return int(datetime.now(UTC).timestamp())
async def _get_file(
self, file_id: str, return_expired: bool = False, action: Action = Action.READ
) -> dict[str, Any]:
where: dict[str, str | dict] = {"id": file_id}
if not return_expired:
where["expires_at"] = {">": self._now()}
if not (row := await self.sql_store.fetch_one("openai_files", where=where, action=action)):
raise ResourceNotFoundError(file_id, "File", "files.list()")
return row
async def _delete_file(self, file_id: str) -> None:
"""Delete a file from S3 and the database."""
try:
self.client.delete_object(
Bucket=self._config.bucket_name,
Key=file_id,
)
except ClientError as e:
if e.response["Error"]["Code"] != "NoSuchKey":
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
await self.sql_store.delete("openai_files", where={"id": file_id})
async def _delete_if_expired(self, file_id: str) -> None:
"""If the file exists and is expired, delete it."""
if row := await self._get_file(file_id, return_expired=True):
if (expires_at := row.get("expires_at")) and expires_at <= self._now():
await self._delete_file(file_id)
async def initialize(self) -> None:
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
await self._sql_store.create_table(
"openai_files",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"filename": ColumnType.STRING,
"purpose": ColumnType.STRING,
"bytes": ColumnType.INTEGER,
"created_at": ColumnType.INTEGER,
"expires_at": ColumnType.INTEGER,
# TODO: add s3_etag field for integrity checking
},
)
async def shutdown(self) -> None:
pass
@property
def client(self) -> "S3Client":
assert self._client is not None, "Provider not initialized"
return self._client
@property
def sql_store(self) -> AuthorizedSqlStore:
assert self._sql_store is not None, "Provider not initialized"
return self._sql_store
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
filename = getattr(file, "filename", None) or "uploaded_file"
created_at = self._now()
# the default is no expiration.
# to implement no expiration we set an expiration beyond the max.
# we'll hide this fact from users when returning the file object.
expires_at = created_at + ExpiresAfter.MAX * 42
# the default for BATCH files is 30 days, which happens to be the expiration max.
if purpose == OpenAIFilePurpose.BATCH:
expires_at = created_at + ExpiresAfter.MAX
if expires_after is not None:
expires_at = created_at + expires_after.seconds
content = await file.read()
file_size = len(content)
entry: dict[str, Any] = {
"id": file_id,
"filename": filename,
"purpose": purpose.value,
"bytes": file_size,
"created_at": created_at,
"expires_at": expires_at,
}
await self.sql_store.insert("openai_files", entry)
try:
self.client.put_object(
Bucket=self._config.bucket_name,
Key=file_id,
Body=content,
# TODO: enable server-side encryption
)
except ClientError as e:
await self.sql_store.delete("openai_files", where={"id": file_id})
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
return _make_file_object(**entry)
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
# this purely defensive. it should not happen because the router also default to Order.desc.
if not order:
order = Order.desc
where_conditions: dict[str, Any] = {"expires_at": {">": self._now()}}
if purpose:
where_conditions["purpose"] = purpose.value
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
where=where_conditions,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
)
files = [_make_file_object(**row) for row in paginated_result.data]
return ListOpenAIFileResponse(
data=files,
has_more=paginated_result.has_more,
# empty string or None? spec says str, ref impl returns str | None, we go with spec
first_id=files[0].id if files else "",
last_id=files[-1].id if files else "",
)
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
await self._delete_if_expired(file_id)
row = await self._get_file(file_id)
return _make_file_object(**row)
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
await self._delete_if_expired(file_id)
_ = await self._get_file(file_id, action=Action.DELETE) # raises if not found
await self._delete_file(file_id)
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
async def openai_retrieve_file_content(self, file_id: str) -> Response:
await self._delete_if_expired(file_id)
row = await self._get_file(file_id)
try:
response = self.client.get_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
# TODO: can we stream this instead of loading it into memory
content = response["Body"].read()
except ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey":
await self._delete_file(file_id)
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
raise RuntimeError(f"Failed to download file from S3: {e}") from e
return Response(
content=content,
media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
)