feat: add s3 provider to files API

Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han 2025-04-01 11:46:42 +02:00
parent e3ad17ec5e
commit 749cbcca31
17 changed files with 614 additions and 132 deletions

View file

@@ -8,6 +8,7 @@ from typing import Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@@ -34,17 +35,6 @@ class BucketResponse(BaseModel):
name: str
@json_schema_type
class ListBucketResponse(BaseModel):
"""
Response representing a list of file entries.
:param data: List of FileResponse entries
"""
data: list[BucketResponse]
@json_schema_type
class FileResponse(BaseModel):
"""
@@ -66,17 +56,6 @@ class FileResponse(BaseModel):
created_at: int
@json_schema_type
class ListFileResponse(BaseModel):
"""
Response representing a list of file entries.
:param data: List of FileResponse entries
"""
data: list[FileResponse]
@runtime_checkable
@trace_protocol
class Files(Protocol):
@@ -98,7 +77,7 @@ class Files(Protocol):
"""
...
@webmethod(route="/files/session:{upload_id}", method="POST", raw_bytes_request_body=True)
@webmethod(route="/files/session/{upload_id}", method="POST", raw_bytes_request_body=True)
async def upload_content_to_session(
self,
upload_id: str,
@@ -111,7 +90,7 @@ class Files(Protocol):
"""
...
@webmethod(route="/files/session:{upload_id}", method="GET")
@webmethod(route="/files/session/{upload_id}", method="GET")
async def get_upload_session_info(
self,
upload_id: str,
@@ -126,10 +105,15 @@
@webmethod(route="/files", method="GET")
async def list_all_buckets(
self,
bucket: str,
) -> ListBucketResponse:
page: int | None = None,
size: int | None = None,
) -> PaginatedResponse:
"""
List all buckets.
:param page: The page number (1-based). If None, starts from first page.
:param size: Number of items per page. If None or -1, returns all items.
:return: PaginatedResponse with the list of buckets
"""
...
@@ -137,11 +121,16 @@ class Files(Protocol):
async def list_files_in_bucket(
self,
bucket: str,
) -> ListFileResponse:
page: int | None = None,
size: int | None = None,
) -> PaginatedResponse:
"""
List all files in a bucket.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
:param page: The page number (1-based). If None, starts from first page.
:param size: Number of items per page. If None or -1, returns all items.
:return: PaginatedResponse with the list of files
"""
...
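The pagination change above drops the ListBucketResponse/ListFileResponse wrappers in favor of the shared PaginatedResponse. A minimal caller-side sketch follows; the files_impl handle and the has_more continuation flag are assumptions (any Files implementation and the field exposed by the pagination utility used later in this diff), not part of this commit.

# Hedged sketch (not part of this commit): paging through a bucket with the new
# signatures. `files_impl` is any Files implementation, e.g. the S3 adapter added
# below; PaginatedResponse.data holds plain dicts produced by paginate_records.
async def dump_bucket(files_impl: Files, bucket: str, page_size: int = 100) -> None:
    page = 1
    while True:
        resp = await files_impl.list_files_in_bucket(bucket=bucket, page=page, size=page_size)
        for entry in resp.data:
            print(entry["key"], entry["bytes"])
        # has_more is assumed to be the PaginatedResponse continuation flag
        if not getattr(resp, "has_more", False):
            break
        page += 1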

View file

@@ -4,8 +4,25 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import ProviderSpec
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
ProviderSpec,
remote_provider_spec,
)
def available_providers() -> list[ProviderSpec]:
return []
return [
remote_provider_spec(
api=Api.files,
adapter=AdapterSpec(
adapter_type="s3",
pip_packages=["aioboto3"],
module="llama_stack.providers.remote.files.object.s3",
config_class="llama_stack.providers.remote.files.object.s3.config.S3FilesImplConfig",
provider_data_validator="llama_stack.providers.remote.files.object.s3.S3ProviderDataValidator",
),
),
]

View file

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import S3FilesImplConfig
async def get_adapter_impl(config: S3FilesImplConfig, _deps):
from .s3_files import S3FilesAdapter
impl = S3FilesAdapter(config)
await impl.initialize()
return impl
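For context, the sketch below shows how the AdapterSpec registered later in this commit would resolve to this entry point; the raw config dict and empty deps are illustrative assumptions, not the actual resolver wiring.

# Illustrative only: building the adapter the way the provider resolver would,
# by parsing the provider config and handing it to get_adapter_impl.
async def build_s3_files(raw_config: dict):
    config = S3FilesImplConfig(**raw_config)
    return await get_adapter_impl(config, _deps={})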

View file

@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class S3FilesImplConfig(BaseModel):
"""Configuration for S3 file storage provider."""
aws_access_key_id: str = Field(description="AWS access key ID")
aws_secret_access_key: str = Field(description="AWS secret access key")
region_name: str | None = Field(default=None, description="AWS region name")
endpoint_url: str | None = Field(default=None, description="Optional endpoint URL for S3 compatible services")
bucket_name: str | None = Field(default=None, description="Default S3 bucket name")
verify_tls: bool = Field(default=True, description="Verify TLS certificates")
persistent_store: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict:
return {
"aws_access_key_id": "your-access-key-id",
"aws_secret_access_key": "your-secret-access-key",
"region_name": "us-west-2",
"endpoint_url": None,
"bucket_name": "your-bucket-name",
"verify_tls": True,
"persistence_store": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="files_s3_store.db",
),
}
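A programmatic construction of the same config might look like the sketch below; the SqliteKVStoreConfig arguments are an assumption based on its use in sample_run_config above, not verified against this commit.

# Hedged sketch (not part of this commit): building the config in Python instead
# of YAML. The db_path keyword on SqliteKVStoreConfig is assumed.
config = S3FilesImplConfig(
    aws_access_key_id="your-access-key-id",
    aws_secret_access_key="your-secret-access-key",
    region_name="us-west-2",
    bucket_name="your-bucket-name",
    persistent_store=SqliteKVStoreConfig(db_path="~/.llama/distributions/files_s3_store.db"),
)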

View file

@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from datetime import datetime, timezone
from pydantic import BaseModel
from llama_stack.apis.files.files import FileUploadResponse
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
class UploadSessionInfo(BaseModel):
"""Information about an upload session."""
upload_id: str
bucket: str
key: str
mime_type: str
size: int
url: str
created_at: datetime
class S3FilesPersistence:
def __init__(self, kvstore: KVStore):
self._kvstore = kvstore
self._store = None
async def _get_store(self) -> KVStore:
"""Get the kvstore instance, initializing it if needed."""
if self._store is None:
self._store = await anext(self._kvstore)
return self._store
async def store_upload_session(
self, session_info: FileUploadResponse, bucket: str, key: str, mime_type: str, size: int
):
"""Store upload session information."""
upload_info = UploadSessionInfo(
upload_id=session_info.id,
bucket=bucket,
key=key,
mime_type=mime_type,
size=size,
url=session_info.url,
created_at=datetime.now(timezone.utc),
)
store = await self._get_store()
await store.set(
key=f"upload_session:{session_info.id}",
value=upload_info.model_dump_json(),
)
async def get_upload_session(self, upload_id: str) -> UploadSessionInfo | None:
"""Get upload session information."""
store = await self._get_store()
value = await store.get(
key=f"upload_session:{upload_id}",
)
if not value:
return None
return UploadSessionInfo(**json.loads(value))
async def delete_upload_session(self, upload_id: str) -> None:
"""Delete upload session information."""
store = await self._get_store()
await store.delete(key=f"upload_session:{upload_id}")
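A minimal usage sketch follows. Note that _get_store() advances the injected value with anext(), so the sketch assumes the constructor is handed an async generator yielding an initialized KVStore; the bucket, key, and size values are illustrative.

# Hedged sketch (not part of this commit): recording and reading back an upload
# session. _get_store() calls anext() on the injected value, so the persistence
# layer is fed an async generator that yields a ready KVStore.
async def _yield_store(store: KVStore):
    yield store

async def remember_upload(store: KVStore, upload: FileUploadResponse) -> UploadSessionInfo | None:
    persistence = S3FilesPersistence(_yield_store(store))
    await persistence.store_upload_session(
        session_info=upload,
        bucket="my-bucket",  # illustrative values
        key="reports/q1.pdf",
        mime_type="application/pdf",
        size=1024,
    )
    return await persistence.get_upload_session(upload.id)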

View file

@@ -0,0 +1,235 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import aioboto3
from botocore.exceptions import ClientError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.files.files import (
BucketResponse,
FileResponse,
Files,
FileUploadResponse,
)
from llama_stack.providers.utils.pagination import paginate_records
from .config import S3FilesImplConfig
class S3FilesAdapter(Files):
def __init__(self, config: S3FilesImplConfig):
self.config = config
self.session = aioboto3.Session(
aws_access_key_id=config.aws_access_key_id,
aws_secret_access_key=config.aws_secret_access_key,
region_name=config.region_name,
)
async def initialize(self):
# TODO: health check?
pass
async def create_upload_session(
self,
bucket: str,
key: str,
mime_type: str,
size: int,
) -> FileUploadResponse:
"""Create a presigned URL for uploading a file to S3."""
try:
async with self.session.client(
"s3",
endpoint_url=self.config.endpoint_url,
) as s3:
url = await s3.generate_presigned_url(
"put_object",
Params={
"Bucket": bucket,
"Key": key,
"ContentType": mime_type,
},
ExpiresIn=3600, # URL expires in 1 hour
)
return FileUploadResponse(
id=f"{bucket}/{key}",
url=url,
offset=0,
size=size,
)
except ClientError as e:
raise Exception(f"Failed to create upload session: {str(e)}") from e
async def upload_content_to_session(
self,
upload_id: str,
) -> FileResponse | None:
"""Upload content to S3 using the upload session."""
bucket, key = upload_id.split("/", 1)
try:
async with self.session.client(
"s3",
endpoint_url=self.config.endpoint_url,
) as s3:
response = await s3.head_object(Bucket=bucket, Key=key)
url = await s3.generate_presigned_url(
"get_object",
Params={
"Bucket": bucket,
"Key": key,
},
ExpiresIn=3600,
)
return FileResponse(
bucket=bucket,
key=key,
mime_type=response.get("ContentType", "application/octet-stream"),
url=url,
bytes=response["ContentLength"],
created_at=int(response["LastModified"].timestamp()),
)
except ClientError:
return None
async def get_upload_session_info(
self,
upload_id: str,
) -> FileUploadResponse:
"""Get information about an upload session."""
bucket, key = upload_id.split("/", 1)
try:
async with self.session.client(
"s3",
endpoint_url=self.config.endpoint_url,
) as s3:
response = await s3.head_object(Bucket=bucket, Key=key)
url = await s3.generate_presigned_url(
"put_object",
Params={
"Bucket": bucket,
"Key": key,
"ContentType": response.get("ContentType", "application/octet-stream"),
},
ExpiresIn=3600,
)
return FileUploadResponse(
id=upload_id,
url=url,
offset=0,
size=response["ContentLength"],
)
except ClientError as e:
raise Exception(f"Failed to get upload session info: {str(e)}") from e
async def list_all_buckets(
self,
page: int | None = None,
size: int | None = None,
) -> PaginatedResponse:
"""List all available S3 buckets."""
try:
async with self.session.client(
"s3",
endpoint_url=self.config.endpoint_url,
) as s3:
response = await s3.list_buckets()
buckets = [BucketResponse(name=bucket["Name"]) for bucket in response["Buckets"]]
# Convert BucketResponse objects to dictionaries for pagination
bucket_dicts = [bucket.model_dump() for bucket in buckets]
return paginate_records(bucket_dicts, page, size)
except ClientError as e:
raise Exception(f"Failed to list buckets: {str(e)}") from e
async def list_files_in_bucket(
self,
bucket: str,
page: int | None = None,
size: int | None = None,
) -> PaginatedResponse:
"""List all files in an S3 bucket."""
try:
async with self.session.client(
"s3",
endpoint_url=self.config.endpoint_url,
) as s3:
response = await s3.list_objects_v2(Bucket=bucket)
files: list[FileResponse] = []
for obj in response.get("Contents", []):
url = await s3.generate_presigned_url(
"get_object",
Params={
"Bucket": bucket,
"Key": obj["Key"],
},
ExpiresIn=3600,
)
files.append(
FileResponse(
bucket=bucket,
key=obj["Key"],
mime_type="application/octet-stream", # Default mime type
url=url,
bytes=obj["Size"],
created_at=int(obj["LastModified"].timestamp()),
)
)
# Convert FileResponse objects to dictionaries for pagination
file_dicts = [file.model_dump() for file in files]
return paginate_records(file_dicts, page, size)
except ClientError as e:
raise Exception(f"Failed to list files in bucket: {str(e)}") from e
async def get_file(
self,
bucket: str,
key: str,
) -> FileResponse:
"""Get information about a specific file in S3."""
try:
async with self.session.client(
"s3",
endpoint_url=self.config.endpoint_url,
) as s3:
response = await s3.head_object(Bucket=bucket, Key=key)
url = await s3.generate_presigned_url(
"get_object",
Params={
"Bucket": bucket,
"Key": key,
},
ExpiresIn=3600,
)
return FileResponse(
bucket=bucket,
key=key,
mime_type=response.get("ContentType", "application/octet-stream"),
url=url,
bytes=response["ContentLength"],
created_at=int(response["LastModified"].timestamp()),
)
except ClientError as e:
raise Exception(f"Failed to get file info: {str(e)}") from e
async def delete_file(
self,
bucket: str,
key: str,
) -> None:
"""Delete a file from S3."""
try:
async with self.session.client(
"s3",
endpoint_url=self.config.endpoint_url,
) as s3:
# Delete the file
await s3.delete_object(Bucket=bucket, Key=key)
except ClientError as e:
raise Exception(f"Failed to delete file: {str(e)}") from e

View file

@@ -459,6 +459,7 @@
"uvicorn"
],
"ollama": [
"aioboto3",
"aiohttp",
"aiosqlite",
"autoevals",

View file

@@ -29,4 +29,6 @@ distribution_spec:
- inline::rag-runtime
- remote::model-context-protocol
- remote::wolfram-alpha
files:
- remote::s3
image_type: conda

View file

@@ -35,6 +35,7 @@ def get_distribution_template() -> DistributionTemplate:
"remote::model-context-protocol",
"remote::wolfram-alpha",
],
"files": ["remote::s3"],
}
name = "ollama"
inference_provider = Provider(
@@ -48,6 +49,20 @@ def get_distribution_template() -> DistributionTemplate:
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
# Add S3 provider configuration
s3_provider = Provider(
provider_id="s3",
provider_type="remote::s3",
config={
"aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:}",
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:}",
"region_name": "${env.AWS_REGION_NAME:}",
"endpoint_url": "${env.AWS_ENDPOINT_URL:}",
"bucket_name": "${env.AWS_BUCKET_NAME:}",
"verify_tls": "${env.AWS_VERIFY_TLS:true}",
},
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="ollama",
@@ -92,6 +107,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider_faiss],
"files": [s3_provider],
},
default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
@@ -100,6 +116,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider_faiss],
"files": [s3_provider],
"safety": [
Provider(
provider_id="llama-guard",
@@ -148,5 +165,30 @@ def get_distribution_template() -> DistributionTemplate:
"meta-llama/Llama-Guard-3-1B",
"Safety model loaded into the Ollama server",
),
# Add AWS S3 environment variables
"AWS_ACCESS_KEY_ID": (
"",
"AWS access key ID for S3 access",
),
"AWS_SECRET_ACCESS_KEY": (
"",
"AWS secret access key for S3 access",
),
"AWS_REGION_NAME": (
"",
"AWS region name for S3 access",
),
"AWS_ENDPOINT_URL": (
"",
"AWS endpoint URL for S3 access (for custom endpoints)",
),
"AWS_BUCKET_NAME": (
"",
"AWS bucket name for S3 access",
),
"AWS_VERIFY_TLS": (
"true",
"Whether to verify TLS for S3 connections",
),
},
)

View file

@@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- safety
- scoring
@@ -101,6 +102,16 @@ providers:
provider_type: remote::wolfram-alpha
config:
api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
files:
- provider_id: s3
provider_type: remote::s3
config:
aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:}
aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:}
region_name: ${env.AWS_REGION_NAME:}
endpoint_url: ${env.AWS_ENDPOINT_URL:}
bucket_name: ${env.AWS_BUCKET_NAME:}
verify_tls: ${env.AWS_VERIFY_TLS:true}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db

View file

@@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- safety
- scoring
@@ -99,6 +100,16 @@ providers:
provider_type: remote::wolfram-alpha
config:
api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
files:
- provider_id: s3
provider_type: remote::s3
config:
aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:}
aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:}
region_name: ${env.AWS_REGION_NAME:}
endpoint_url: ${env.AWS_ENDPOINT_URL:}
bucket_name: ${env.AWS_BUCKET_NAME:}
verify_tls: ${env.AWS_VERIFY_TLS:true}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db