diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index f82a7cdd2..19d96cce2 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -24,7 +24,7 @@ jobs: matrix: # Listing tests manually since some of them currently fail # TODO: generate matrix list from tests/integration when fixed - test-type: [agents, inference, datasets, inspect, scoring, post_training, providers] + test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, files] client-type: [library, http] fail-fast: false # we want to run all tests regardless of failure diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 4020dc4cd..7fa2390db 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -568,11 +568,11 @@ "get": { "responses": { "200": { - "description": "OK", + "description": "PaginatedResponse with the list of buckets", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ListBucketResponse" + "$ref": "#/components/schemas/PaginatedResponse" } } } @@ -596,11 +596,21 @@ "description": "List all buckets.", "parameters": [ { - "name": "bucket", + "name": "page", "in": "query", - "required": true, + "description": "The page number (1-based). If None, starts from first page.", + "required": false, "schema": { - "type": "string" + "type": "integer" + } + }, + { + "name": "size", + "in": "query", + "description": "Number of items per page. 
If None or -1, returns all items.", + "required": false, + "schema": { + "type": "integer" } } ] @@ -1850,7 +1860,7 @@ "parameters": [] } }, - "/v1/files/session:{upload_id}": { + "/v1/files/session/{upload_id}": { "get": { "responses": { "200": { @@ -2631,11 +2641,11 @@ "get": { "responses": { "200": { - "description": "OK", + "description": "PaginatedResponse with the list of files", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ListFileResponse" + "$ref": "#/components/schemas/PaginatedResponse" } } } @@ -2666,6 +2676,24 @@ "schema": { "type": "string" } + }, + { + "name": "page", + "in": "query", + "description": "The page number (1-based). If None, starts from first page.", + "required": false, + "schema": { + "type": "integer" + } + }, + { + "name": "size", + "in": "query", + "description": "Number of items per page. If None or -1, returns all items.", + "required": false, + "schema": { + "type": "integer" + } } ] } @@ -9085,37 +9113,6 @@ ], "title": "Job" }, - "BucketResponse": { - "type": "object", - "properties": { - "name": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "name" - ], - "title": "BucketResponse" - }, - "ListBucketResponse": { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/components/schemas/BucketResponse" - }, - "description": "List of FileResponse entries" - } - }, - "additionalProperties": false, - "required": [ - "data" - ], - "title": "ListBucketResponse", - "description": "Response representing a list of file entries." 
- }, "ListBenchmarksResponse": { "type": "object", "properties": { @@ -9148,24 +9145,6 @@ ], "title": "ListDatasetsResponse" }, - "ListFileResponse": { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/components/schemas/FileResponse" - }, - "description": "List of FileResponse entries" - } - }, - "additionalProperties": false, - "required": [ - "data" - ], - "title": "ListFileResponse", - "description": "Response representing a list of file entries." - }, "ListModelsResponse": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 62e3ca85c..aa630e001 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -379,11 +379,12 @@ paths: get: responses: '200': - description: OK + description: >- + PaginatedResponse with the list of buckets content: application/json: schema: - $ref: '#/components/schemas/ListBucketResponse' + $ref: '#/components/schemas/PaginatedResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -398,11 +399,20 @@ paths: - Files description: List all buckets. parameters: - - name: bucket + - name: page in: query - required: true + description: >- + The page number (1-based). If None, starts from first page. + required: false schema: - type: string + type: integer + - name: size + in: query + description: >- + Number of items per page. If None or -1, returns all items. 
+ required: false + schema: + type: integer post: responses: '200': @@ -1261,7 +1271,7 @@ paths: - PostTraining (Coming Soon) description: '' parameters: [] - /v1/files/session:{upload_id}: + /v1/files/session/{upload_id}: get: responses: '200': @@ -1816,11 +1826,11 @@ paths: get: responses: '200': - description: OK + description: PaginatedResponse with the list of files content: application/json: schema: - $ref: '#/components/schemas/ListFileResponse' + $ref: '#/components/schemas/PaginatedResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1841,6 +1851,20 @@ paths: required: true schema: type: string + - name: page + in: query + description: >- + The page number (1-based). If None, starts from first page. + required: false + schema: + type: integer + - name: size + in: query + description: >- + Number of items per page. If None or -1, returns all items. + required: false + schema: + type: integer /v1/models: get: responses: @@ -6277,29 +6301,6 @@ components: - job_id - status title: Job - BucketResponse: - type: object - properties: - name: - type: string - additionalProperties: false - required: - - name - title: BucketResponse - ListBucketResponse: - type: object - properties: - data: - type: array - items: - $ref: '#/components/schemas/BucketResponse' - description: List of FileResponse entries - additionalProperties: false - required: - - data - title: ListBucketResponse - description: >- - Response representing a list of file entries. ListBenchmarksResponse: type: object properties: @@ -6322,20 +6323,6 @@ components: required: - data title: ListDatasetsResponse - ListFileResponse: - type: object - properties: - data: - type: array - items: - $ref: '#/components/schemas/FileResponse' - description: List of FileResponse entries - additionalProperties: false - required: - - data - title: ListFileResponse - description: >- - Response representing a list of file entries. 
ListModelsResponse: type: object properties: diff --git a/docs/source/distributions/self_hosted_distro/ollama.md b/docs/source/distributions/self_hosted_distro/ollama.md index 5d8935fe2..5f43ae38c 100644 --- a/docs/source/distributions/self_hosted_distro/ollama.md +++ b/docs/source/distributions/self_hosted_distro/ollama.md @@ -18,6 +18,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov | agents | `inline::meta-reference` | | datasetio | `remote::huggingface`, `inline::localfs` | | eval | `inline::meta-reference` | +| files | `remote::s3` | | inference | `remote::ollama` | | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | @@ -36,6 +37,12 @@ The following environment variables can be configured: - `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`) - `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`) +- `AWS_ACCESS_KEY_ID`: AWS access key ID for S3 access (default: ``) +- `AWS_SECRET_ACCESS_KEY`: AWS secret access key for S3 access (default: ``) +- `AWS_REGION_NAME`: AWS region name for S3 access (default: ``) +- `AWS_ENDPOINT_URL`: AWS endpoint URL for S3 access (for custom endpoints) (default: ``) +- `AWS_BUCKET_NAME`: AWS bucket name for S3 access (default: ``) +- `AWS_VERIFY_TLS`: Whether to verify TLS for S3 connections (default: `true`) ## Setting up Ollama server diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py index 4a9b49978..e0d89b1ed 100644 --- a/llama_stack/apis/files/files.py +++ b/llama_stack/apis/files/files.py @@ -8,6 +8,7 @@ from typing import Protocol, runtime_checkable from pydantic import BaseModel +from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol 
from llama_stack.schema_utils import json_schema_type, webmethod @@ -34,17 +35,6 @@ class BucketResponse(BaseModel): name: str -@json_schema_type -class ListBucketResponse(BaseModel): - """ - Response representing a list of file entries. - - :param data: List of FileResponse entries - """ - - data: list[BucketResponse] - - @json_schema_type class FileResponse(BaseModel): """ @@ -66,17 +56,6 @@ class FileResponse(BaseModel): created_at: int -@json_schema_type -class ListFileResponse(BaseModel): - """ - Response representing a list of file entries. - - :param data: List of FileResponse entries - """ - - data: list[FileResponse] - - @runtime_checkable @trace_protocol class Files(Protocol): @@ -98,7 +77,7 @@ class Files(Protocol): """ ... - @webmethod(route="/files/session:{upload_id}", method="POST", raw_bytes_request_body=True) + @webmethod(route="/files/session/{upload_id}", method="POST", raw_bytes_request_body=True) async def upload_content_to_session( self, upload_id: str, @@ -111,7 +90,7 @@ class Files(Protocol): """ ... - @webmethod(route="/files/session:{upload_id}", method="GET") + @webmethod(route="/files/session/{upload_id}", method="GET") async def get_upload_session_info( self, upload_id: str, @@ -126,10 +105,15 @@ class Files(Protocol): @webmethod(route="/files", method="GET") async def list_all_buckets( self, - bucket: str, - ) -> ListBucketResponse: + page: int | None = None, + size: int | None = None, + ) -> PaginatedResponse: """ List all buckets. + + :param page: The page number (1-based). If None, starts from first page. + :param size: Number of items per page. If None or -1, returns all items. + :return: PaginatedResponse with the list of buckets """ ... @@ -137,11 +121,16 @@ class Files(Protocol): async def list_files_in_bucket( self, bucket: str, - ) -> ListFileResponse: + page: int | None = None, + size: int | None = None, + ) -> PaginatedResponse: """ List all files in a bucket. 
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-) + :param page: The page number (1-based). If None, starts from first page. + :param size: Number of items per page. If None or -1, returns all items. + :return: PaginatedResponse with the list of files """ ... diff --git a/llama_stack/providers/registry/files.py b/llama_stack/providers/registry/files.py index fb23436bb..46d62d820 100644 --- a/llama_stack/providers/registry/files.py +++ b/llama_stack/providers/registry/files.py @@ -4,8 +4,25 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.datatypes import ProviderSpec + +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + ProviderSpec, + remote_provider_spec, +) def available_providers() -> list[ProviderSpec]: - return [] + return [ + remote_provider_spec( + api=Api.files, + adapter=AdapterSpec( + adapter_type="s3", + pip_packages=["aioboto3"], + module="llama_stack.providers.remote.files.object.s3", + config_class="llama_stack.providers.remote.files.object.s3.config.S3FilesImplConfig", + provider_data_validator="llama_stack.providers.remote.files.object.s3.S3ProviderDataValidator", + ), + ), + ] diff --git a/llama_stack/providers/remote/files/object/s3/__init__.py b/llama_stack/providers/remote/files/object/s3/__init__.py new file mode 100644 index 000000000..13bd16230 --- /dev/null +++ b/llama_stack/providers/remote/files/object/s3/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from .config import S3FilesImplConfig + + +async def get_adapter_impl(config: S3FilesImplConfig, _deps): + from .s3_files import S3FilesAdapter + + impl = S3FilesAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/files/object/s3/config.py b/llama_stack/providers/remote/files/object/s3/config.py new file mode 100644 index 000000000..992100018 --- /dev/null +++ b/llama_stack/providers/remote/files/object/s3/config.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.kvstore import KVStoreConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + + +class S3FilesImplConfig(BaseModel): + """Configuration for S3 file storage provider.""" + + aws_access_key_id: str = Field(description="AWS access key ID") + aws_secret_access_key: str = Field(description="AWS secret access key") + region_name: str | None = Field(default=None, description="AWS region name") + endpoint_url: str | None = Field(default=None, description="Optional endpoint URL for S3 compatible services") + bucket_name: str | None = Field(default=None, description="Default S3 bucket name") + verify_tls: bool = Field(default=True, description="Verify TLS certificates") + persistent_store: KVStoreConfig + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> dict: + return { + "aws_access_key_id": "your-access-key-id", + "aws_secret_access_key": "your-secret-access-key", + "region_name": "us-west-2", + "endpoint_url": None, + "bucket_name": "your-bucket-name", + "verify_tls": True, + "persistent_store": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="files_s3_store.db", + ), + } diff --git
a/llama_stack/providers/remote/files/object/s3/persistence.py b/llama_stack/providers/remote/files/object/s3/persistence.py new file mode 100644 index 000000000..7f27eece3 --- /dev/null +++ b/llama_stack/providers/remote/files/object/s3/persistence.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import logging +from datetime import datetime, timezone + +from pydantic import BaseModel + +from llama_stack.apis.files.files import FileUploadResponse +from llama_stack.providers.utils.kvstore import KVStore + +log = logging.getLogger(__name__) + + +class UploadSessionInfo(BaseModel): + """Information about an upload session.""" + + upload_id: str + bucket: str + key: str + mime_type: str + size: int + url: str + created_at: datetime + + +class S3FilesPersistence: + def __init__(self, kvstore: KVStore): + self._kvstore = kvstore + self._store = None + + async def _get_store(self) -> KVStore: + """Get the kvstore instance, initializing it if needed.""" + if self._store is None: + self._store = await anext(self._kvstore) + return self._store + + async def store_upload_session( + self, session_info: FileUploadResponse, bucket: str, key: str, mime_type: str, size: int + ): + """Store upload session information.""" + upload_info = UploadSessionInfo( + upload_id=session_info.id, + bucket=bucket, + key=key, + mime_type=mime_type, + size=size, + url=session_info.url, + created_at=datetime.now(timezone.utc), + ) + + store = await self._get_store() + await store.set( + key=f"upload_session:{session_info.id}", + value=upload_info.model_dump_json(), + ) + + async def get_upload_session(self, upload_id: str) -> UploadSessionInfo | None: + """Get upload session information.""" + store = await self._get_store() + value = await store.get( + key=f"upload_session:{upload_id}", + ) + if not value: + 
return None + + return UploadSessionInfo(**json.loads(value)) + + async def delete_upload_session(self, upload_id: str) -> None: + """Delete upload session information.""" + store = await self._get_store() + await store.delete(key=f"upload_session:{upload_id}") diff --git a/llama_stack/providers/remote/files/object/s3/s3_files.py b/llama_stack/providers/remote/files/object/s3/s3_files.py new file mode 100644 index 000000000..45dff31ed --- /dev/null +++ b/llama_stack/providers/remote/files/object/s3/s3_files.py @@ -0,0 +1,235 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import aioboto3 +from botocore.exceptions import ClientError + +from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.apis.files.files import ( + BucketResponse, + FileResponse, + Files, + FileUploadResponse, +) +from llama_stack.providers.utils.pagination import paginate_records + +from .config import S3FilesImplConfig + + +class S3FilesAdapter(Files): + def __init__(self, config: S3FilesImplConfig): + self.config = config + self.session = aioboto3.Session( + aws_access_key_id=config.aws_access_key_id, + aws_secret_access_key=config.aws_secret_access_key, + region_name=config.region_name, + ) + + async def initialize(self): + # TODO: health check?
+ pass + + async def create_upload_session( + self, + bucket: str, + key: str, + mime_type: str, + size: int, + ) -> FileUploadResponse: + """Create a presigned URL for uploading a file to S3.""" + try: + async with self.session.client( + "s3", + endpoint_url=self.config.endpoint_url, + ) as s3: + url = await s3.generate_presigned_url( + "put_object", + Params={ + "Bucket": bucket, + "Key": key, + "ContentType": mime_type, + }, + ExpiresIn=3600, # URL expires in 1 hour + ) + return FileUploadResponse( + id=f"{bucket}/{key}", + url=url, + offset=0, + size=size, + ) + except ClientError as e: + raise Exception(f"Failed to create upload session: {str(e)}") from e + + async def upload_content_to_session( + self, + upload_id: str, + ) -> FileResponse | None: + """Upload content to S3 using the upload session.""" + bucket, key = upload_id.split("/", 1) + try: + async with self.session.client( + "s3", + endpoint_url=self.config.endpoint_url, + ) as s3: + response = await s3.head_object(Bucket=bucket, Key=key) + url = await s3.generate_presigned_url( + "get_object", + Params={ + "Bucket": bucket, + "Key": key, + }, + ExpiresIn=3600, + ) + return FileResponse( + bucket=bucket, + key=key, + mime_type=response.get("ContentType", "application/octet-stream"), + url=url, + bytes=response["ContentLength"], + created_at=int(response["LastModified"].timestamp()), + ) + except ClientError: + return None + + async def get_upload_session_info( + self, + upload_id: str, + ) -> FileUploadResponse: + """Get information about an upload session.""" + bucket, key = upload_id.split("/", 1) + try: + async with self.session.client( + "s3", + endpoint_url=self.config.endpoint_url, + ) as s3: + response = await s3.head_object(Bucket=bucket, Key=key) + url = await s3.generate_presigned_url( + "put_object", + Params={ + "Bucket": bucket, + "Key": key, + "ContentType": response.get("ContentType", "application/octet-stream"), + }, + ExpiresIn=3600, + ) + return FileUploadResponse( + id=upload_id, + 
url=url, + offset=0, + size=response["ContentLength"], + ) + except ClientError as e: + raise Exception(f"Failed to get upload session info: {str(e)}") from e + + async def list_all_buckets( + self, + page: int | None = None, + size: int | None = None, + ) -> PaginatedResponse: + """List all available S3 buckets.""" + + try: + async with self.session.client( + "s3", + endpoint_url=self.config.endpoint_url, + ) as s3: + response = await s3.list_buckets() + buckets = [BucketResponse(name=bucket["Name"]) for bucket in response["Buckets"]] + # Convert BucketResponse objects to dictionaries for pagination + bucket_dicts = [bucket.model_dump() for bucket in buckets] + return paginate_records(bucket_dicts, page, size) + except ClientError as e: + raise Exception(f"Failed to list buckets: {str(e)}") from e + + async def list_files_in_bucket( + self, + bucket: str, + page: int | None = None, + size: int | None = None, + ) -> PaginatedResponse: + """List all files in an S3 bucket.""" + try: + async with self.session.client( + "s3", + endpoint_url=self.config.endpoint_url, + ) as s3: + response = await s3.list_objects_v2(Bucket=bucket) + files: list[FileResponse] = [] + + for obj in response.get("Contents", []): + url = await s3.generate_presigned_url( + "get_object", + Params={ + "Bucket": bucket, + "Key": obj["Key"], + }, + ExpiresIn=3600, + ) + + files.append( + FileResponse( + bucket=bucket, + key=obj["Key"], + mime_type="application/octet-stream", # Default mime type + url=url, + bytes=obj["Size"], + created_at=int(obj["LastModified"].timestamp()), + ) + ) + + # Convert FileResponse objects to dictionaries for pagination + file_dicts = [file.model_dump() for file in files] + return paginate_records(file_dicts, page, size) + except ClientError as e: + raise Exception(f"Failed to list files in bucket: {str(e)}") from e + + async def get_file( + self, + bucket: str, + key: str, + ) -> FileResponse: + """Get information about a specific file in S3.""" + try: + async with 
self.session.client( + "s3", + endpoint_url=self.config.endpoint_url, + ) as s3: + response = await s3.head_object(Bucket=bucket, Key=key) + url = await s3.generate_presigned_url( + "get_object", + Params={ + "Bucket": bucket, + "Key": key, + }, + ExpiresIn=3600, + ) + + return FileResponse( + bucket=bucket, + key=key, + mime_type=response.get("ContentType", "application/octet-stream"), + url=url, + bytes=response["ContentLength"], + created_at=int(response["LastModified"].timestamp()), + ) + except ClientError as e: + raise Exception(f"Failed to get file info: {str(e)}") from e + + async def delete_file( + self, + bucket: str, + key: str, + ) -> None: + """Delete a file from S3.""" + try: + async with self.session.client( + "s3", + endpoint_url=self.config.endpoint_url, + ) as s3: + # Delete the file + await s3.delete_object(Bucket=bucket, Key=key) + except ClientError as e: + raise Exception(f"Failed to delete file: {str(e)}") from e diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index 35cbc8878..c97f56604 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -459,6 +459,7 @@ "uvicorn" ], "ollama": [ + "aioboto3", "aiohttp", "aiosqlite", "autoevals", diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 88e61bf8a..26142f5db 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -29,4 +29,6 @@ distribution_spec: - inline::rag-runtime - remote::model-context-protocol - remote::wolfram-alpha + files: + - remote::s3 image_type: conda diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py index d72d299ec..826df4e3e 100644 --- a/llama_stack/templates/ollama/ollama.py +++ b/llama_stack/templates/ollama/ollama.py @@ -35,6 +35,7 @@ def get_distribution_template() -> DistributionTemplate: "remote::model-context-protocol", "remote::wolfram-alpha", ], 
+ "files": ["remote::s3"], } name = "ollama" inference_provider = Provider( @@ -48,6 +49,20 @@ def get_distribution_template() -> DistributionTemplate: config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), ) + # Add S3 provider configuration + s3_provider = Provider( + provider_id="s3", + provider_type="remote::s3", + config={ + "aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:}", + "aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:}", + "region_name": "${env.AWS_REGION_NAME:}", + "endpoint_url": "${env.AWS_ENDPOINT_URL:}", + "bucket_name": "${env.AWS_BUCKET_NAME:}", + "verify_tls": "${env.AWS_VERIFY_TLS:true}", + }, + ) + inference_model = ModelInput( model_id="${env.INFERENCE_MODEL}", provider_id="ollama", @@ -92,6 +107,7 @@ def get_distribution_template() -> DistributionTemplate: provider_overrides={ "inference": [inference_provider], "vector_io": [vector_io_provider_faiss], + "files": [s3_provider], }, default_models=[inference_model, embedding_model], default_tool_groups=default_tool_groups, @@ -100,6 +116,7 @@ def get_distribution_template() -> DistributionTemplate: provider_overrides={ "inference": [inference_provider], "vector_io": [vector_io_provider_faiss], + "files": [s3_provider], "safety": [ Provider( provider_id="llama-guard", @@ -148,5 +165,30 @@ def get_distribution_template() -> DistributionTemplate: "meta-llama/Llama-Guard-3-1B", "Safety model loaded into the Ollama server", ), + # Add AWS S3 environment variables + "AWS_ACCESS_KEY_ID": ( + "", + "AWS access key ID for S3 access", + ), + "AWS_SECRET_ACCESS_KEY": ( + "", + "AWS secret access key for S3 access", + ), + "AWS_REGION_NAME": ( + "", + "AWS region name for S3 access", + ), + "AWS_ENDPOINT_URL": ( + "", + "AWS endpoint URL for S3 access (for custom endpoints)", + ), + "AWS_BUCKET_NAME": ( + "", + "AWS bucket name for S3 access", + ), + "AWS_VERIFY_TLS": ( + "true", + "Whether to verify TLS for S3 connections", + ), }, ) diff --git 
a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index 9f3f2a505..15308197f 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -4,6 +4,7 @@ apis: - agents - datasetio - eval +- files - inference - safety - scoring @@ -101,6 +102,16 @@ providers: provider_type: remote::wolfram-alpha config: api_key: ${env.WOLFRAM_ALPHA_API_KEY:} + files: + - provider_id: s3 + provider_type: remote::s3 + config: + aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:} + aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:} + region_name: ${env.AWS_REGION_NAME:} + endpoint_url: ${env.AWS_ENDPOINT_URL:} + bucket_name: ${env.AWS_BUCKET_NAME:} + verify_tls: ${env.AWS_VERIFY_TLS:true} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 66b0d77d7..233799078 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -4,6 +4,7 @@ apis: - agents - datasetio - eval +- files - inference - safety - scoring @@ -99,6 +100,16 @@ providers: provider_type: remote::wolfram-alpha config: api_key: ${env.WOLFRAM_ALPHA_API_KEY:} + files: + - provider_id: s3 + provider_type: remote::s3 + config: + aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:} + aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:} + region_name: ${env.AWS_REGION_NAME:} + endpoint_url: ${env.AWS_ENDPOINT_URL:} + bucket_name: ${env.AWS_BUCKET_NAME:} + verify_tls: ${env.AWS_VERIFY_TLS:true} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db diff --git a/tests/integration/files/conftest.py b/tests/integration/files/conftest.py new file mode 100644 index 000000000..132c333f3 --- /dev/null +++ b/tests/integration/files/conftest.py @@ -0,0 +1,38 @@ +from typing import AsyncGenerator + +import pytest 
+ +from llama_stack.providers.remote.files.object.s3.config import S3FilesImplConfig +from llama_stack.providers.remote.files.object.s3.s3_files import S3FilesAdapter +from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + + +@pytest.fixture +def s3_config(): + """Create S3 configuration for MinIO.""" + return S3FilesImplConfig( + aws_access_key_id="ROOTNAME", + aws_secret_access_key="CHANGEME123", + region_name="us-east-1", + endpoint_url="http://localhost:9000", + ) + + +@pytest.fixture +async def kvstore() -> AsyncGenerator[KVStore, None]: + """Create a SQLite KV store for testing.""" + config = SqliteKVStoreConfig( + path=":memory:" # Use in-memory SQLite for tests + ) + store = await kvstore_impl(config) + await store.initialize() + yield store + + +@pytest.fixture +async def s3_files(s3_config, kvstore) -> AsyncGenerator[S3FilesAdapter, None]: + """Create S3FilesAdapter instance for testing.""" + adapter = S3FilesAdapter(s3_config, kvstore) + await adapter.initialize() + yield adapter diff --git a/tests/unit/providers/files/test_remote_files_s3.py b/tests/unit/providers/files/test_remote_files_s3.py new file mode 100644 index 000000000..ef80d7269 --- /dev/null +++ b/tests/unit/providers/files/test_remote_files_s3.py @@ -0,0 +1,35 @@ +import pytest + +from llama_stack.providers.remote.files.object.s3.config import S3FilesImplConfig +from llama_stack.providers.remote.files.object.s3.s3_files import S3FilesAdapter + + +@pytest.fixture +def s3_config(): + return S3FilesImplConfig( + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + region_name="us-east-1", + endpoint_url="http://localhost:9000", + ) + + +@pytest.fixture +async def s3_files(s3_config): + adapter = S3FilesAdapter(s3_config) + await adapter.initialize() + return adapter + + +@pytest.mark.asyncio +async def test_create_upload_session(s3_files): + bucket = "test-bucket" + key = 
"test-file.txt" + mime_type = "text/plain" + size = 1024 + + response = await s3_files.create_upload_session(bucket, key, mime_type, size) + assert response.id == f"{bucket}/{key}" + assert response.size == size + assert response.offset == 0 + assert response.url is not None