Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 16:24:44 +00:00)

feat: add s3 provider to files API

Signed-off-by: Sébastien Han <seb@redhat.com>

Parent: e3ad17ec5e
Commit: 749cbcca31

17 changed files with 614 additions and 132 deletions
.github/workflows/integration-tests.yml (vendored, 2 changes)

@@ -24,7 +24,7 @@ jobs:
       matrix:
         # Listing tests manually since some of them currently fail
         # TODO: generate matrix list from tests/integration when fixed
-        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers]
+        test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, files]
         client-type: [library, http]
       fail-fast: false # we want to run all tests regardless of failure
docs/_static/llama-stack-spec.html (vendored, 93 changes)

@@ -568,11 +568,11 @@
       "get": {
         "responses": {
           "200": {
-            "description": "OK",
+            "description": "PaginatedResponse with the list of buckets",
             "content": {
               "application/json": {
                 "schema": {
-                  "$ref": "#/components/schemas/ListBucketResponse"
+                  "$ref": "#/components/schemas/PaginatedResponse"
                 }
               }
             }

@@ -596,11 +596,21 @@
         "description": "List all buckets.",
         "parameters": [
           {
-            "name": "bucket",
+            "name": "page",
             "in": "query",
-            "required": true,
+            "description": "The page number (1-based). If None, starts from first page.",
+            "required": false,
             "schema": {
-              "type": "string"
+              "type": "integer"
             }
-          }
+          },
+          {
+            "name": "size",
+            "in": "query",
+            "description": "Number of items per page. If None or -1, returns all items.",
+            "required": false,
+            "schema": {
+              "type": "integer"
+            }
+          }
         ]

@@ -1850,7 +1860,7 @@
         "parameters": []
       }
     },
-    "/v1/files/session:{upload_id}": {
+    "/v1/files/session/{upload_id}": {
       "get": {
         "responses": {
           "200": {

@@ -2631,11 +2641,11 @@
       "get": {
         "responses": {
           "200": {
-            "description": "OK",
+            "description": "PaginatedResponse with the list of files",
             "content": {
               "application/json": {
                 "schema": {
-                  "$ref": "#/components/schemas/ListFileResponse"
+                  "$ref": "#/components/schemas/PaginatedResponse"
                 }
               }
             }

@@ -2666,6 +2676,24 @@
             "schema": {
               "type": "string"
             }
           },
+          {
+            "name": "page",
+            "in": "query",
+            "description": "The page number (1-based). If None, starts from first page.",
+            "required": false,
+            "schema": {
+              "type": "integer"
+            }
+          },
+          {
+            "name": "size",
+            "in": "query",
+            "description": "Number of items per page. If None or -1, returns all items.",
+            "required": false,
+            "schema": {
+              "type": "integer"
+            }
+          }
         ]

@@ -9085,37 +9113,6 @@
       ],
       "title": "Job"
     },
-    "BucketResponse": {
-      "type": "object",
-      "properties": {
-        "name": {
-          "type": "string"
-        }
-      },
-      "additionalProperties": false,
-      "required": [
-        "name"
-      ],
-      "title": "BucketResponse"
-    },
-    "ListBucketResponse": {
-      "type": "object",
-      "properties": {
-        "data": {
-          "type": "array",
-          "items": {
-            "$ref": "#/components/schemas/BucketResponse"
-          },
-          "description": "List of FileResponse entries"
-        }
-      },
-      "additionalProperties": false,
-      "required": [
-        "data"
-      ],
-      "title": "ListBucketResponse",
-      "description": "Response representing a list of file entries."
-    },
     "ListBenchmarksResponse": {
       "type": "object",
       "properties": {

@@ -9148,24 +9145,6 @@
       ],
       "title": "ListDatasetsResponse"
     },
-    "ListFileResponse": {
-      "type": "object",
-      "properties": {
-        "data": {
-          "type": "array",
-          "items": {
-            "$ref": "#/components/schemas/FileResponse"
-          },
-          "description": "List of FileResponse entries"
-        }
-      },
-      "additionalProperties": false,
-      "required": [
-        "data"
-      ],
-      "title": "ListFileResponse",
-      "description": "Response representing a list of file entries."
-    },
     "ListModelsResponse": {
       "type": "object",
       "properties": {
docs/_static/llama-stack-spec.yaml (vendored, 77 changes)

@@ -379,11 +379,12 @@ paths:
     get:
       responses:
         '200':
-          description: OK
+          description: >-
+            PaginatedResponse with the list of buckets
           content:
             application/json:
               schema:
-                $ref: '#/components/schemas/ListBucketResponse'
+                $ref: '#/components/schemas/PaginatedResponse'
         '400':
           $ref: '#/components/responses/BadRequest400'
         '429':

@@ -398,11 +399,20 @@ paths:
       - Files
       description: List all buckets.
       parameters:
-        - name: bucket
+        - name: page
           in: query
-          required: true
+          description: >-
+            The page number (1-based). If None, starts from first page.
+          required: false
           schema:
-            type: string
+            type: integer
+        - name: size
+          in: query
+          description: >-
+            Number of items per page. If None or -1, returns all items.
+          required: false
+          schema:
+            type: integer
     post:
       responses:
         '200':

@@ -1261,7 +1271,7 @@ paths:
       - PostTraining (Coming Soon)
       description: ''
       parameters: []
-  /v1/files/session:{upload_id}:
+  /v1/files/session/{upload_id}:
     get:
       responses:
         '200':

@@ -1816,11 +1826,11 @@ paths:
     get:
       responses:
         '200':
-          description: OK
+          description: PaginatedResponse with the list of files
          content:
            application/json:
              schema:
-                $ref: '#/components/schemas/ListFileResponse'
+                $ref: '#/components/schemas/PaginatedResponse'
         '400':
           $ref: '#/components/responses/BadRequest400'
         '429':

@@ -1841,6 +1851,20 @@ paths:
           required: true
           schema:
             type: string
+        - name: page
+          in: query
+          description: >-
+            The page number (1-based). If None, starts from first page.
+          required: false
+          schema:
+            type: integer
+        - name: size
+          in: query
+          description: >-
+            Number of items per page. If None or -1, returns all items.
+          required: false
+          schema:
+            type: integer
   /v1/models:
     get:
       responses:

@@ -6277,29 +6301,6 @@ components:
       - job_id
       - status
       title: Job
-    BucketResponse:
-      type: object
-      properties:
-        name:
-          type: string
-      additionalProperties: false
-      required:
-        - name
-      title: BucketResponse
-    ListBucketResponse:
-      type: object
-      properties:
-        data:
-          type: array
-          items:
-            $ref: '#/components/schemas/BucketResponse'
-          description: List of FileResponse entries
-      additionalProperties: false
-      required:
-        - data
-      title: ListBucketResponse
-      description: >-
-        Response representing a list of file entries.
     ListBenchmarksResponse:
       type: object
       properties:

@@ -6322,20 +6323,6 @@ components:
       required:
         - data
       title: ListDatasetsResponse
-    ListFileResponse:
-      type: object
-      properties:
-        data:
-          type: array
-          items:
-            $ref: '#/components/schemas/FileResponse'
-          description: List of FileResponse entries
-      additionalProperties: false
-      required:
-        - data
-      title: ListFileResponse
-      description: >-
-        Response representing a list of file entries.
     ListModelsResponse:
       type: object
       properties:
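To see the new pagination surface from a client's side, here is a hypothetical request against the buckets listing documented above; the host, port (8321 is assumed to be the stack's default), and page values are placeholders, not part of this commit:

# Hypothetical paginated request; host/port and params are placeholders.
import httpx

resp = httpx.get("http://localhost:8321/v1/files", params={"page": 1, "size": 20})
resp.raise_for_status()
for bucket in resp.json()["data"]:
    print(bucket["name"])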
docs/source/distributions/self_hosted_distro/ollama.md

@@ -18,6 +18,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
+| files | `remote::s3` |
 | inference | `remote::ollama` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |

@@ -36,6 +37,12 @@ The following environment variables can be configured:
 - `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
 - `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)
+- `AWS_ACCESS_KEY_ID`: AWS access key ID for S3 access (default: ``)
+- `AWS_SECRET_ACCESS_KEY`: AWS secret access key for S3 access (default: ``)
+- `AWS_REGION_NAME`: AWS region name for S3 access (default: ``)
+- `AWS_ENDPOINT_URL`: AWS endpoint URL for S3 access (for custom endpoints) (default: ``)
+- `AWS_BUCKET_NAME`: AWS bucket name for S3 access (default: ``)
+- `AWS_VERIFY_TLS`: Whether to verify TLS for S3 connections (default: `true`)
 
 ## Setting up Ollama server
llama_stack/apis/files/files.py

@@ -8,6 +8,7 @@ from typing import Protocol, runtime_checkable
 
 from pydantic import BaseModel
 
+from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, webmethod
 
@@ -34,17 +35,6 @@ class BucketResponse(BaseModel):
     name: str
 
 
-@json_schema_type
-class ListBucketResponse(BaseModel):
-    """
-    Response representing a list of file entries.
-
-    :param data: List of FileResponse entries
-    """
-
-    data: list[BucketResponse]
-
-
 @json_schema_type
 class FileResponse(BaseModel):
     """
@@ -66,17 +56,6 @@ class FileResponse(BaseModel):
     created_at: int
 
 
-@json_schema_type
-class ListFileResponse(BaseModel):
-    """
-    Response representing a list of file entries.
-
-    :param data: List of FileResponse entries
-    """
-
-    data: list[FileResponse]
-
-
 @runtime_checkable
 @trace_protocol
 class Files(Protocol):
@@ -98,7 +77,7 @@ class Files(Protocol):
         """
         ...
 
-    @webmethod(route="/files/session:{upload_id}", method="POST", raw_bytes_request_body=True)
+    @webmethod(route="/files/session/{upload_id}", method="POST", raw_bytes_request_body=True)
     async def upload_content_to_session(
         self,
         upload_id: str,
@@ -111,7 +90,7 @@ class Files(Protocol):
         """
         ...
 
-    @webmethod(route="/files/session:{upload_id}", method="GET")
+    @webmethod(route="/files/session/{upload_id}", method="GET")
     async def get_upload_session_info(
         self,
         upload_id: str,
@@ -126,10 +105,15 @@ class Files(Protocol):
     @webmethod(route="/files", method="GET")
     async def list_all_buckets(
         self,
-        bucket: str,
-    ) -> ListBucketResponse:
+        page: int | None = None,
+        size: int | None = None,
+    ) -> PaginatedResponse:
         """
         List all buckets.
+
+        :param page: The page number (1-based). If None, starts from first page.
+        :param size: Number of items per page. If None or -1, returns all items.
+        :return: PaginatedResponse with the list of buckets
         """
         ...
 
@@ -137,11 +121,16 @@ class Files(Protocol):
     async def list_files_in_bucket(
         self,
         bucket: str,
-    ) -> ListFileResponse:
+        page: int | None = None,
+        size: int | None = None,
+    ) -> PaginatedResponse:
         """
         List all files in a bucket.
 
         :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
+        :param page: The page number (1-based). If None, starts from first page.
+        :param size: Number of items per page. If None or -1, returns all items.
+        :return: PaginatedResponse with the list of files
         """
         ...
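Both listing methods now return `PaginatedResponse` instead of the removed `List*Response` wrappers, so callers page explicitly. A minimal paging sketch, assuming `PaginatedResponse` exposes `data` and `has_more` fields as in the common responses module:

async def iter_files(files_api, bucket: str, page_size: int = 100):
    """Yield every file record in `bucket`, one page at a time."""
    page = 1
    while True:
        resp = await files_api.list_files_in_bucket(bucket=bucket, page=page, size=page_size)
        for record in resp.data:
            yield record
        if not resp.has_more:
            break
        page += 1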
llama_stack/providers/registry/files.py

@@ -4,8 +4,25 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.datatypes import ProviderSpec
+from llama_stack.providers.datatypes import (
+    AdapterSpec,
+    Api,
+    ProviderSpec,
+    remote_provider_spec,
+)
 
 
 def available_providers() -> list[ProviderSpec]:
-    return []
+    return [
+        remote_provider_spec(
+            api=Api.files,
+            adapter=AdapterSpec(
+                adapter_type="s3",
+                pip_packages=["aioboto3"],
+                module="llama_stack.providers.remote.files.object.s3",
+                config_class="llama_stack.providers.remote.files.object.s3.config.S3FilesImplConfig",
+                provider_data_validator="llama_stack.providers.remote.files.object.s3.S3ProviderDataValidator",
+            ),
+        ),
+    ]
llama_stack/providers/remote/files/object/s3/__init__.py (new file, 15 lines)

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import S3FilesImplConfig


async def get_adapter_impl(config: S3FilesImplConfig, _deps):
    from .s3_files import S3FilesAdapter

    impl = S3FilesAdapter(config)
    await impl.initialize()
    return impl
llama_stack/providers/remote/files/object/s3/config.py (new file, 37 lines)

@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel, Field

from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


class S3FilesImplConfig(BaseModel):
    """Configuration for the S3 file storage provider."""

    aws_access_key_id: str = Field(description="AWS access key ID")
    aws_secret_access_key: str = Field(description="AWS secret access key")
    region_name: str | None = Field(default=None, description="AWS region name")
    endpoint_url: str | None = Field(default=None, description="Optional endpoint URL for S3 compatible services")
    bucket_name: str | None = Field(default=None, description="Default S3 bucket name")
    verify_tls: bool = Field(default=True, description="Verify TLS certificates")
    persistent_store: KVStoreConfig

    @classmethod
    def sample_run_config(cls, __distro_dir__: str) -> dict:
        return {
            "aws_access_key_id": "your-access-key-id",
            "aws_secret_access_key": "your-secret-access-key",
            "region_name": "us-west-2",
            "endpoint_url": None,
            "bucket_name": "your-bucket-name",
            "verify_tls": True,
            "persistent_store": SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="files_s3_store.db",
            ),
        }
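A minimal construction sketch for this config, assuming a local MinIO endpoint; the credentials are placeholders and `db_path` is an assumption about `SqliteKVStoreConfig`'s field name:

import asyncio

from llama_stack.providers.remote.files.object.s3 import get_adapter_impl
from llama_stack.providers.remote.files.object.s3.config import S3FilesImplConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

config = S3FilesImplConfig(
    aws_access_key_id="minioadmin",  # placeholder credentials
    aws_secret_access_key="minioadmin",
    region_name="us-east-1",
    endpoint_url="http://localhost:9000",  # e.g. a local MinIO server
    bucket_name="my-bucket",
    verify_tls=False,
    persistent_store=SqliteKVStoreConfig(db_path="/tmp/files_s3_store.db"),  # assumed field name
)

# The stack normally resolves the registry entry and does this itself.
adapter = asyncio.run(get_adapter_impl(config, {}))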
llama_stack/providers/remote/files/object/s3/persistence.py (new file, 76 lines)

@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import logging
from datetime import datetime, timezone

from pydantic import BaseModel

from llama_stack.apis.files.files import FileUploadResponse
from llama_stack.providers.utils.kvstore import KVStore

log = logging.getLogger(__name__)


class UploadSessionInfo(BaseModel):
    """Information about an upload session."""

    upload_id: str
    bucket: str
    key: str
    mime_type: str
    size: int
    url: str
    created_at: datetime


class S3FilesPersistence:
    def __init__(self, kvstore: KVStore):
        self._kvstore = kvstore
        self._store = None

    async def _get_store(self) -> KVStore:
        """Get the kvstore instance, initializing it if needed."""
        if self._store is None:
            self._store = await anext(self._kvstore)
        return self._store

    async def store_upload_session(
        self, session_info: FileUploadResponse, bucket: str, key: str, mime_type: str, size: int
    ):
        """Store upload session information."""
        upload_info = UploadSessionInfo(
            upload_id=session_info.id,
            bucket=bucket,
            key=key,
            mime_type=mime_type,
            size=size,
            url=session_info.url,
            created_at=datetime.now(timezone.utc),
        )

        store = await self._get_store()
        await store.set(
            key=f"upload_session:{session_info.id}",
            value=upload_info.model_dump_json(),
        )

    async def get_upload_session(self, upload_id: str) -> UploadSessionInfo | None:
        """Get upload session information."""
        store = await self._get_store()
        value = await store.get(
            key=f"upload_session:{upload_id}",
        )
        if not value:
            return None

        return UploadSessionInfo(**json.loads(value))

    async def delete_upload_session(self, upload_id: str) -> None:
        """Delete upload session information."""
        store = await self._get_store()
        await store.delete(key=f"upload_session:{upload_id}")
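For reference, a sketch of the record `store_upload_session` persists; the values below are placeholders:

from llama_stack.apis.files.files import FileUploadResponse

# store_upload_session(session, "my-bucket", "notes.txt", "text/plain", 11) would
# write, under the key "upload_session:my-bucket/notes.txt", a JSON document like:
# {"upload_id": "my-bucket/notes.txt", "bucket": "my-bucket", "key": "notes.txt",
#  "mime_type": "text/plain", "size": 11, "url": "<presigned-url>", "created_at": "..."}
session = FileUploadResponse(id="my-bucket/notes.txt", url="<presigned-url>", offset=0, size=11)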
llama_stack/providers/remote/files/object/s3/s3_files.py (new file, 235 lines)

@@ -0,0 +1,235 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import aioboto3
from botocore.exceptions import ClientError

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.files.files import (
    BucketResponse,
    FileResponse,
    Files,
    FileUploadResponse,
)
from llama_stack.providers.utils.pagination import paginate_records

from .config import S3FilesImplConfig


class S3FilesAdapter(Files):
    def __init__(self, config: S3FilesImplConfig):
        self.config = config
        self.session = aioboto3.Session(
            aws_access_key_id=config.aws_access_key_id,
            aws_secret_access_key=config.aws_secret_access_key,
            region_name=config.region_name,
        )

    async def initialize(self):
        # TODO: health check?
        pass

    async def create_upload_session(
        self,
        bucket: str,
        key: str,
        mime_type: str,
        size: int,
    ) -> FileUploadResponse:
        """Create a presigned URL for uploading a file to S3."""
        try:
            async with self.session.client(
                "s3",
                endpoint_url=self.config.endpoint_url,
            ) as s3:
                url = await s3.generate_presigned_url(
                    "put_object",
                    Params={
                        "Bucket": bucket,
                        "Key": key,
                        "ContentType": mime_type,
                    },
                    ExpiresIn=3600,  # URL expires in 1 hour
                )
                return FileUploadResponse(
                    id=f"{bucket}/{key}",
                    url=url,
                    offset=0,
                    size=size,
                )
        except ClientError as e:
            raise Exception(f"Failed to create upload session: {str(e)}") from e

    async def upload_content_to_session(
        self,
        upload_id: str,
    ) -> FileResponse | None:
        """Upload content to S3 using the upload session."""
        bucket, key = upload_id.split("/", 1)
        try:
            async with self.session.client(
                "s3",
                endpoint_url=self.config.endpoint_url,
            ) as s3:
                response = await s3.head_object(Bucket=bucket, Key=key)
                url = await s3.generate_presigned_url(
                    "get_object",
                    Params={
                        "Bucket": bucket,
                        "Key": key,
                    },
                    ExpiresIn=3600,
                )
                return FileResponse(
                    bucket=bucket,
                    key=key,
                    mime_type=response.get("ContentType", "application/octet-stream"),
                    url=url,
                    bytes=response["ContentLength"],
                    created_at=int(response["LastModified"].timestamp()),
                )
        except ClientError:
            return None

    async def get_upload_session_info(
        self,
        upload_id: str,
    ) -> FileUploadResponse:
        """Get information about an upload session."""
        bucket, key = upload_id.split("/", 1)
        try:
            async with self.session.client(
                "s3",
                endpoint_url=self.config.endpoint_url,
            ) as s3:
                response = await s3.head_object(Bucket=bucket, Key=key)
                url = await s3.generate_presigned_url(
                    "put_object",
                    Params={
                        "Bucket": bucket,
                        "Key": key,
                        "ContentType": response.get("ContentType", "application/octet-stream"),
                    },
                    ExpiresIn=3600,
                )
                return FileUploadResponse(
                    id=upload_id,
                    url=url,
                    offset=0,
                    size=response["ContentLength"],
                )
        except ClientError as e:
            raise Exception(f"Failed to get upload session info: {str(e)}") from e

    async def list_all_buckets(
        self,
        page: int | None = None,
        size: int | None = None,
    ) -> PaginatedResponse:
        """List all available S3 buckets."""
        try:
            async with self.session.client(
                "s3",
                endpoint_url=self.config.endpoint_url,
            ) as s3:
                response = await s3.list_buckets()
                buckets = [BucketResponse(name=bucket["Name"]) for bucket in response["Buckets"]]
                # Convert BucketResponse objects to dictionaries for pagination
                bucket_dicts = [bucket.model_dump() for bucket in buckets]
                return paginate_records(bucket_dicts, page, size)
        except ClientError as e:
            raise Exception(f"Failed to list buckets: {str(e)}") from e

    async def list_files_in_bucket(
        self,
        bucket: str,
        page: int | None = None,
        size: int | None = None,
    ) -> PaginatedResponse:
        """List all files in an S3 bucket."""
        try:
            async with self.session.client(
                "s3",
                endpoint_url=self.config.endpoint_url,
            ) as s3:
                response = await s3.list_objects_v2(Bucket=bucket)
                files: list[FileResponse] = []

                for obj in response.get("Contents", []):
                    url = await s3.generate_presigned_url(
                        "get_object",
                        Params={
                            "Bucket": bucket,
                            "Key": obj["Key"],
                        },
                        ExpiresIn=3600,
                    )

                    files.append(
                        FileResponse(
                            bucket=bucket,
                            key=obj["Key"],
                            mime_type="application/octet-stream",  # Default mime type
                            url=url,
                            bytes=obj["Size"],
                            created_at=int(obj["LastModified"].timestamp()),
                        )
                    )

                # Convert FileResponse objects to dictionaries for pagination
                file_dicts = [file.model_dump() for file in files]
                return paginate_records(file_dicts, page, size)
        except ClientError as e:
            raise Exception(f"Failed to list files in bucket: {str(e)}") from e

    async def get_file(
        self,
        bucket: str,
        key: str,
    ) -> FileResponse:
        """Get information about a specific file in S3."""
        try:
            async with self.session.client(
                "s3",
                endpoint_url=self.config.endpoint_url,
            ) as s3:
                response = await s3.head_object(Bucket=bucket, Key=key)
                url = await s3.generate_presigned_url(
                    "get_object",
                    Params={
                        "Bucket": bucket,
                        "Key": key,
                    },
                    ExpiresIn=3600,
                )

                return FileResponse(
                    bucket=bucket,
                    key=key,
                    mime_type=response.get("ContentType", "application/octet-stream"),
                    url=url,
                    bytes=response["ContentLength"],
                    created_at=int(response["LastModified"].timestamp()),
                )
        except ClientError as e:
            raise Exception(f"Failed to get file info: {str(e)}") from e

    async def delete_file(
        self,
        bucket: str,
        key: str,
    ) -> None:
        """Delete a file from S3."""
        try:
            async with self.session.client(
                "s3",
                endpoint_url=self.config.endpoint_url,
            ) as s3:
                # Delete the file
                await s3.delete_object(Bucket=bucket, Key=key)
        except ClientError as e:
            raise Exception(f"Failed to delete file: {str(e)}") from e
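Putting the adapter's presigned-URL flow together: the client asks for an upload session, PUTs the bytes straight to S3, then confirms through the adapter. A sketch assuming an already-initialized `adapter` and an existing bucket:

import aiohttp

async def upload_and_verify(adapter, bucket: str, key: str, payload: bytes):
    # 1. Mint a presigned PUT URL; the session id encodes "<bucket>/<key>".
    session = await adapter.create_upload_session(bucket, key, "text/plain", len(payload))
    # 2. The client uploads directly to S3, bypassing the stack server.
    async with aiohttp.ClientSession() as http:
        async with http.put(session.url, data=payload, headers={"Content-Type": "text/plain"}) as resp:
            resp.raise_for_status()
    # 3. The adapter HEADs the object and returns a presigned GET URL for it.
    return await adapter.upload_content_to_session(session.id)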
llama_stack/templates/dependencies.json

@@ -459,6 +459,7 @@
     "uvicorn"
   ],
   "ollama": [
+    "aioboto3",
     "aiohttp",
     "aiosqlite",
     "autoevals",
llama_stack/templates/ollama/build.yaml

@@ -29,4 +29,6 @@ distribution_spec:
     - inline::rag-runtime
     - remote::model-context-protocol
     - remote::wolfram-alpha
+    files:
+    - remote::s3
 image_type: conda
llama_stack/templates/ollama/ollama.py

@@ -35,6 +35,7 @@ def get_distribution_template() -> DistributionTemplate:
             "remote::model-context-protocol",
             "remote::wolfram-alpha",
         ],
+        "files": ["remote::s3"],
     }
     name = "ollama"
     inference_provider = Provider(

@@ -48,6 +49,20 @@ def get_distribution_template() -> DistributionTemplate:
         config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
     )
 
+    # Add S3 provider configuration
+    s3_provider = Provider(
+        provider_id="s3",
+        provider_type="remote::s3",
+        config={
+            "aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:}",
+            "aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:}",
+            "region_name": "${env.AWS_REGION_NAME:}",
+            "endpoint_url": "${env.AWS_ENDPOINT_URL:}",
+            "bucket_name": "${env.AWS_BUCKET_NAME:}",
+            "verify_tls": "${env.AWS_VERIFY_TLS:true}",
+        },
+    )
+
     inference_model = ModelInput(
         model_id="${env.INFERENCE_MODEL}",
         provider_id="ollama",

@@ -92,6 +107,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_overrides={
                 "inference": [inference_provider],
                 "vector_io": [vector_io_provider_faiss],
+                "files": [s3_provider],
             },
             default_models=[inference_model, embedding_model],
             default_tool_groups=default_tool_groups,

@@ -100,6 +116,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_overrides={
                 "inference": [inference_provider],
                 "vector_io": [vector_io_provider_faiss],
+                "files": [s3_provider],
                 "safety": [
                     Provider(
                         provider_id="llama-guard",

@@ -148,5 +165,30 @@ def get_distribution_template() -> DistributionTemplate:
                 "meta-llama/Llama-Guard-3-1B",
                 "Safety model loaded into the Ollama server",
             ),
+            # Add AWS S3 environment variables
+            "AWS_ACCESS_KEY_ID": (
+                "",
+                "AWS access key ID for S3 access",
+            ),
+            "AWS_SECRET_ACCESS_KEY": (
+                "",
+                "AWS secret access key for S3 access",
+            ),
+            "AWS_REGION_NAME": (
+                "",
+                "AWS region name for S3 access",
+            ),
+            "AWS_ENDPOINT_URL": (
+                "",
+                "AWS endpoint URL for S3 access (for custom endpoints)",
+            ),
+            "AWS_BUCKET_NAME": (
+                "",
+                "AWS bucket name for S3 access",
+            ),
+            "AWS_VERIFY_TLS": (
+                "true",
+                "Whether to verify TLS for S3 connections",
+            ),
         },
     )
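The `${env.VAR:default}` strings above are resolved from the environment at run time. A toy illustration of that convention (this is not the stack's actual resolver):

import os
import re

def resolve(value: str) -> str:
    """Substitute `${env.VAR:default}` placeholders from the environment."""
    def repl(m: re.Match) -> str:
        name, default = m.group(1), m.group(2) or ""
        return os.environ.get(name, default)
    return re.sub(r"\$\{env\.([A-Za-z0-9_]+):?([^}]*)\}", repl, value)

assert resolve("${env.AWS_VERIFY_TLS:true}") == "true"  # when AWS_VERIFY_TLS is unset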
llama_stack/templates/ollama/run-with-safety.yaml

@@ -4,6 +4,7 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - safety
 - scoring

@@ -101,6 +102,16 @@ providers:
       provider_type: remote::wolfram-alpha
       config:
         api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
+  files:
+  - provider_id: s3
+    provider_type: remote::s3
+    config:
+      aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:}
+      aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:}
+      region_name: ${env.AWS_REGION_NAME:}
+      endpoint_url: ${env.AWS_ENDPOINT_URL:}
+      bucket_name: ${env.AWS_BUCKET_NAME:}
+      verify_tls: ${env.AWS_VERIFY_TLS:true}
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
llama_stack/templates/ollama/run.yaml

@@ -4,6 +4,7 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - safety
 - scoring

@@ -99,6 +100,16 @@ providers:
       provider_type: remote::wolfram-alpha
       config:
         api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
+  files:
+  - provider_id: s3
+    provider_type: remote::s3
+    config:
+      aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:}
+      aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:}
+      region_name: ${env.AWS_REGION_NAME:}
+      endpoint_url: ${env.AWS_ENDPOINT_URL:}
+      bucket_name: ${env.AWS_BUCKET_NAME:}
+      verify_tls: ${env.AWS_VERIFY_TLS:true}
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
tests/integration/files/conftest.py (new file, 38 lines)

@@ -0,0 +1,38 @@
from typing import AsyncGenerator

import pytest

from llama_stack.providers.remote.files.object.s3.config import S3FilesImplConfig
from llama_stack.providers.remote.files.object.s3.s3_files import S3FilesAdapter
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


@pytest.fixture
def s3_config():
    """Create S3 configuration for MinIO."""
    return S3FilesImplConfig(
        aws_access_key_id="ROOTNAME",
        aws_secret_access_key="CHANGEME123",
        region_name="us-east-1",
        endpoint_url="http://localhost:9000",
    )


@pytest.fixture
async def kvstore() -> AsyncGenerator[KVStore, None]:
    """Create a SQLite KV store for testing."""
    config = SqliteKVStoreConfig(
        db_path=":memory:"  # Use in-memory SQLite for tests
    )
    store = await kvstore_impl(config)
    await store.initialize()
    yield store


@pytest.fixture
async def s3_files(s3_config, kvstore) -> AsyncGenerator[S3FilesAdapter, None]:
    """Create S3FilesAdapter instance for testing."""
    adapter = S3FilesAdapter(s3_config)
    await adapter.initialize()
    yield adapter
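A sketch of an integration test built on these fixtures; it assumes a MinIO server matching `s3_config` (http://localhost:9000) with a `test-bucket` already created:

import pytest

@pytest.mark.asyncio
async def test_upload_session_roundtrip(s3_files):
    # create_upload_session only presigns a URL, so no object needs to exist yet.
    session = await s3_files.create_upload_session("test-bucket", "hello.txt", "text/plain", 5)
    assert session.id == "test-bucket/hello.txt"
    assert session.url.startswith("http")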
tests/unit/providers/files/test_remote_files_s3.py (new file, 35 lines)

@@ -0,0 +1,35 @@
import pytest

from llama_stack.providers.remote.files.object.s3.config import S3FilesImplConfig
from llama_stack.providers.remote.files.object.s3.s3_files import S3FilesAdapter


@pytest.fixture
def s3_config():
    return S3FilesImplConfig(
        aws_access_key_id="test-key",
        aws_secret_access_key="test-secret",
        region_name="us-east-1",
        endpoint_url="http://localhost:9000",
    )


@pytest.fixture
async def s3_files(s3_config):
    adapter = S3FilesAdapter(s3_config)
    await adapter.initialize()
    return adapter


@pytest.mark.asyncio
async def test_create_upload_session(s3_files):
    bucket = "test-bucket"
    key = "test-file.txt"
    mime_type = "text/plain"
    size = 1024

    response = await s3_files.create_upload_session(bucket, key, mime_type, size)
    assert response.id == f"{bucket}/{key}"
    assert response.size == size
    assert response.offset == 0
    assert response.url is not None
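As written, the unit test above still reaches for http://localhost:9000. One way to run it hermetically is to swap the adapter's `aioboto3` session for a stub; `AsyncMock` and `asynccontextmanager` are standard library, and the rest is an assumption about how the adapter uses the session:

from contextlib import asynccontextmanager
from unittest.mock import AsyncMock

def make_stub_session(url: str = "https://example.com/presigned"):
    """Build a stand-in for aioboto3.Session whose client presigns a fixed URL."""
    s3 = AsyncMock()
    s3.generate_presigned_url.return_value = url

    @asynccontextmanager
    async def client(*_args, **_kwargs):
        yield s3

    session = AsyncMock()
    session.client = client
    return session

# e.g. inside a test: s3_files.session = make_stub_session()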