diff --git a/docs/docs/providers/files/remote_openai.mdx b/docs/docs/providers/files/remote_openai.mdx new file mode 100644 index 000000000..3b5c40aad --- /dev/null +++ b/docs/docs/providers/files/remote_openai.mdx @@ -0,0 +1,27 @@ +--- +description: "OpenAI Files API provider for managing files through OpenAI's native file storage service." +sidebar_label: Remote - Openai +title: remote::openai +--- + +# remote::openai + +## Description + +OpenAI Files API provider for managing files through OpenAI's native file storage service. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `api_key` | `` | No | | OpenAI API key for authentication | +| `metadata_store` | `` | No | | SQL store configuration for file metadata | + +## Sample Configuration + +```yaml +api_key: ${env.OPENAI_API_KEY} +metadata_store: + table_name: openai_files_metadata + backend: sql_default +``` diff --git a/src/llama_stack/providers/registry/files.py b/src/llama_stack/providers/registry/files.py index 9acabfacd..3f5949ba2 100644 --- a/src/llama_stack/providers/registry/files.py +++ b/src/llama_stack/providers/registry/files.py @@ -28,4 +28,13 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", ), + RemoteProviderSpec( + api=Api.files, + provider_type="remote::openai", + adapter_type="openai", + pip_packages=["openai"] + sql_store_pip_packages, + module="llama_stack.providers.remote.files.openai", + config_class="llama_stack.providers.remote.files.openai.config.OpenAIFilesImplConfig", + description="OpenAI Files API provider for managing files through OpenAI's native file storage service.", + ), ] diff --git a/src/llama_stack/providers/remote/files/openai/__init__.py b/src/llama_stack/providers/remote/files/openai/__init__.py new file mode 100644 index 000000000..58f86ecfd --- /dev/null +++ b/src/llama_stack/providers/remote/files/openai/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.core.datatypes import AccessRule, Api + +from .config import OpenAIFilesImplConfig + + +async def get_adapter_impl(config: OpenAIFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None): + from .files import OpenAIFilesImpl + + impl = OpenAIFilesImpl(config, policy or []) + await impl.initialize() + return impl diff --git a/src/llama_stack/providers/remote/files/openai/config.py b/src/llama_stack/providers/remote/files/openai/config.py new file mode 100644 index 000000000..a38031e41 --- /dev/null +++ b/src/llama_stack/providers/remote/files/openai/config.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.core.storage.datatypes import SqlStoreReference + + +class OpenAIFilesImplConfig(BaseModel): + """Configuration for OpenAI Files API provider.""" + + api_key: str = Field(description="OpenAI API key for authentication") + metadata_store: SqlStoreReference = Field(description="SQL store configuration for file metadata") + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: + return { + "api_key": "${env.OPENAI_API_KEY}", + "metadata_store": SqlStoreReference( + backend="sql_default", + table_name="openai_files_metadata", + ).model_dump(exclude_none=True), + } diff --git a/src/llama_stack/providers/remote/files/openai/files.py b/src/llama_stack/providers/remote/files/openai/files.py new file mode 100644 index 000000000..c5d4194df --- /dev/null +++ b/src/llama_stack/providers/remote/files/openai/files.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import UTC, datetime +from typing import Annotated, Any + +from fastapi import Depends, File, Form, Response, UploadFile + +from llama_stack.apis.common.errors import ResourceNotFoundError +from llama_stack.apis.common.responses import Order +from llama_stack.apis.files import ( + ExpiresAfter, + Files, + ListOpenAIFileResponse, + OpenAIFileDeleteResponse, + OpenAIFileObject, + OpenAIFilePurpose, +) +from llama_stack.core.datatypes import AccessRule +from llama_stack.providers.utils.files.form_data import parse_expires_after +from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType +from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl +from openai import OpenAI + +from .config import OpenAIFilesImplConfig + + +def _make_file_object( + *, + id: str, + filename: str, + purpose: str, + bytes: int, + created_at: int, + expires_at: int, + **kwargs: Any, +) -> OpenAIFileObject: + """ + Construct an OpenAIFileObject and normalize expires_at. + + If expires_at is greater than the max we treat it as no-expiration and + return None for expires_at. + """ + obj = OpenAIFileObject( + id=id, + filename=filename, + purpose=OpenAIFilePurpose(purpose), + bytes=bytes, + created_at=created_at, + expires_at=expires_at, + ) + + if obj.expires_at is not None and obj.expires_at > (obj.created_at + ExpiresAfter.MAX): + obj.expires_at = None # type: ignore + + return obj + + +class OpenAIFilesImpl(Files): + """OpenAI Files API implementation.""" + + def __init__(self, config: OpenAIFilesImplConfig, policy: list[AccessRule]) -> None: + self._config = config + self.policy = policy + self._client: OpenAI | None = None + self._sql_store: AuthorizedSqlStore | None = None + + def _now(self) -> int: + """Return current UTC timestamp as int seconds.""" + return int(datetime.now(UTC).timestamp()) + + async def _get_file(self, file_id: str, return_expired: bool = False) -> dict[str, Any]: + where: dict[str, str | dict] = {"id": file_id} + if not return_expired: + where["expires_at"] = {">": self._now()} + if not (row := await self.sql_store.fetch_one("openai_files", where=where)): + raise ResourceNotFoundError(file_id, "File", "files.list()") + return row + + async def _delete_file(self, file_id: str) -> None: + """Delete a file from OpenAI and the database.""" + try: + self.client.files.delete(file_id) + except Exception as e: + # If file doesn't exist on OpenAI side, just remove from metadata store + if "not found" not in str(e).lower(): + raise RuntimeError(f"Failed to delete file from OpenAI: {e}") from e + + await self.sql_store.delete("openai_files", where={"id": file_id}) + + async def _delete_if_expired(self, file_id: str) -> None: + """If the file exists and is expired, delete it.""" + if row := await self._get_file(file_id, return_expired=True): + if (expires_at := row.get("expires_at")) and expires_at <= self._now(): + await self._delete_file(file_id) + + async def initialize(self) -> None: + self._client = OpenAI(api_key=self._config.api_key) + + self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy) + await self._sql_store.create_table( + "openai_files", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "filename": ColumnType.STRING, + "purpose": ColumnType.STRING, + "bytes": ColumnType.INTEGER, + "created_at": ColumnType.INTEGER, + "expires_at": ColumnType.INTEGER, + }, + ) + + async def shutdown(self) -> None: + pass + + @property + def client(self) -> OpenAI: + assert self._client is not None, "Provider not initialized" + return self._client + + @property + def sql_store(self) -> AuthorizedSqlStore: + assert self._sql_store is not None, "Provider not initialized" + return self._sql_store + + async def openai_upload_file( + self, + file: Annotated[UploadFile, File()], + purpose: Annotated[OpenAIFilePurpose, Form()], + expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None, + ) -> OpenAIFileObject: + filename = getattr(file, "filename", None) or "uploaded_file" + content = await file.read() + file_size = len(content) + + created_at = self._now() + + expires_at = created_at + ExpiresAfter.MAX * 42 + if purpose == OpenAIFilePurpose.BATCH: + expires_at = created_at + ExpiresAfter.MAX + + if expires_after is not None: + expires_at = created_at + expires_after.seconds + + try: + from io import BytesIO + + file_obj = BytesIO(content) + file_obj.name = filename + + response = self.client.files.create( + file=file_obj, + purpose=purpose.value, + ) + + file_id = response.id + + entry: dict[str, Any] = { + "id": file_id, + "filename": filename, + "purpose": purpose.value, + "bytes": file_size, + "created_at": created_at, + "expires_at": expires_at, + } + + await self.sql_store.insert("openai_files", entry) + + return _make_file_object(**entry) + + except Exception as e: + raise RuntimeError(f"Failed to upload file to OpenAI: {e}") from e + + async def openai_list_files( + self, + after: str | None = None, + limit: int | None = 10000, + order: Order | None = Order.desc, + purpose: OpenAIFilePurpose | None = None, + ) -> ListOpenAIFileResponse: + if not order: + order = Order.desc + + where_conditions: dict[str, Any] = {"expires_at": {">": self._now()}} + if purpose: + where_conditions["purpose"] = purpose.value + + paginated_result = await self.sql_store.fetch_all( + table="openai_files", + where=where_conditions, + order_by=[("created_at", order.value)], + cursor=("id", after) if after else None, + limit=limit, + ) + + files = [_make_file_object(**row) for row in paginated_result.data] + + return ListOpenAIFileResponse( + data=files, + has_more=paginated_result.has_more, + first_id=files[0].id if files else "", + last_id=files[-1].id if files else "", + ) + + async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: + await self._delete_if_expired(file_id) + row = await self._get_file(file_id) + return _make_file_object(**row) + + async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse: + await self._delete_if_expired(file_id) + _ = await self._get_file(file_id) + await self._delete_file(file_id) + return OpenAIFileDeleteResponse(id=file_id, deleted=True) + + async def openai_retrieve_file_content(self, file_id: str) -> Response: + await self._delete_if_expired(file_id) + + row = await self._get_file(file_id) + + try: + response = self.client.files.content(file_id) + file_content = response.content + + except Exception as e: + if "not found" in str(e).lower(): + await self._delete_file(file_id) + raise ResourceNotFoundError(file_id, "File", "files.list()") from e + raise RuntimeError(f"Failed to download file from OpenAI: {e}") from e + + return Response( + content=file_content, + media_type="application/octet-stream", + headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'}, + ) diff --git a/tests/integration/files/test_files.py b/tests/integration/files/test_files.py index 516b0bd98..d9e8dd501 100644 --- a/tests/integration/files/test_files.py +++ b/tests/integration/files/test_files.py @@ -10,8 +10,18 @@ from unittest.mock import patch import pytest import requests +from llama_stack.apis.files import OpenAIFilePurpose from llama_stack.core.datatypes import User +purpose = OpenAIFilePurpose.ASSISTANTS + + +@pytest.fixture() +def provider_type_is_openai(llama_stack_client): + providers = [provider for provider in llama_stack_client.providers.list() if provider.api == "files"] + assert len(providers) == 1, "Expected exactly one files provider" + return providers[0].provider_type == "remote::openai" + # a fixture to skip all these tests if a files provider is not available @pytest.fixture(autouse=True) @@ -20,7 +30,7 @@ def skip_if_no_files_provider(llama_stack_client): pytest.skip("No files providers found") -def test_openai_client_basic_operations(openai_client): +def test_openai_client_basic_operations(openai_client, provider_type_is_openai): """Test basic file operations through OpenAI client.""" from openai import NotFoundError @@ -34,7 +44,7 @@ def test_openai_client_basic_operations(openai_client): # Upload file using OpenAI client with BytesIO(test_content) as file_buffer: file_buffer.name = "openai_test.txt" - uploaded_file = client.files.create(file=file_buffer, purpose="assistants") + uploaded_file = client.files.create(file=file_buffer, purpose=purpose) # Verify basic response structure assert uploaded_file.id.startswith("file-") @@ -50,16 +60,18 @@ def test_openai_client_basic_operations(openai_client): retrieved_file = client.files.retrieve(uploaded_file.id) assert retrieved_file.id == uploaded_file.id - # Retrieve file content - OpenAI client returns httpx Response object - content_response = client.files.content(uploaded_file.id) - assert content_response.content == test_content + # Retrieve file content + # OpenAI provider does not allow content retrieval with many `purpose` values + if not provider_type_is_openai: + content_response = client.files.content(uploaded_file.id) + assert content_response.content == test_content # Delete file delete_response = client.files.delete(uploaded_file.id) assert delete_response.deleted is True # Retrieve file should fail - with pytest.raises(NotFoundError, match="not found"): + with pytest.raises(NotFoundError): client.files.retrieve(uploaded_file.id) # File should not be found in listing @@ -68,7 +80,7 @@ def test_openai_client_basic_operations(openai_client): assert uploaded_file.id not in file_ids # Double delete should fail - with pytest.raises(NotFoundError, match="not found"): + with pytest.raises(NotFoundError): client.files.delete(uploaded_file.id) finally: @@ -91,7 +103,7 @@ def test_expires_after(openai_client): file_buffer.name = "expires_after.txt" uploaded_file = client.files.create( file=file_buffer, - purpose="assistants", + purpose=purpose, expires_after={"anchor": "created_at", "seconds": 4545}, ) @@ -126,7 +138,7 @@ def test_expires_after_requests(openai_client): try: files = {"file": ("expires_after_with_requests.txt", BytesIO(b"expires_after via requests"))} data = { - "purpose": "assistants", + "purpose": str(purpose), "expires_after[anchor]": "created_at", "expires_after[seconds]": "4545", } @@ -180,7 +192,7 @@ def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack with BytesIO(test_content_1) as file_buffer: file_buffer.name = "user1_file.txt" - user1_file = client.files.create(file=file_buffer, purpose="assistants") + user1_file = client.files.create(file=file_buffer, purpose=purpose) # User 2 uploads a file mock_get_authenticated_user.return_value = user2 @@ -188,7 +200,7 @@ def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack with BytesIO(test_content_2) as file_buffer: file_buffer.name = "user2_file.txt" - user2_file = client.files.create(file=file_buffer, purpose="assistants") + user2_file = client.files.create(file=file_buffer, purpose=purpose) try: # User 1 can see their own file @@ -264,7 +276,9 @@ def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") -def test_files_authentication_shared_attributes(mock_get_authenticated_user, llama_stack_client): +def test_files_authentication_shared_attributes( + mock_get_authenticated_user, llama_stack_client, provider_type_is_openai +): """Test access control with users having identical attributes.""" client = llama_stack_client @@ -278,7 +292,7 @@ def test_files_authentication_shared_attributes(mock_get_authenticated_user, lla with BytesIO(test_content) as file_buffer: file_buffer.name = "shared_attributes_file.txt" - shared_file = client.files.create(file=file_buffer, purpose="assistants") + shared_file = client.files.create(file=file_buffer, purpose=purpose) try: # User B with identical attributes can access the file @@ -294,12 +308,13 @@ def test_files_authentication_shared_attributes(mock_get_authenticated_user, lla assert retrieved_file.id == shared_file.id # User B can access file content - content_response = client.files.content(shared_file.id) - if isinstance(content_response, str): - content = bytes(content_response, "utf-8") - else: - content = content_response.content - assert content == test_content + if not provider_type_is_openai: + content_response = client.files.content(shared_file.id) + if isinstance(content_response, str): + content = bytes(content_response, "utf-8") + else: + content = content_response.content + assert content == test_content # Cleanup mock_get_authenticated_user.return_value = user_a @@ -321,7 +336,9 @@ def test_files_authentication_shared_attributes(mock_get_authenticated_user, lla @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") -def test_files_authentication_anonymous_access(mock_get_authenticated_user, llama_stack_client): +def test_files_authentication_anonymous_access( + mock_get_authenticated_user, llama_stack_client, provider_type_is_openai +): client = llama_stack_client # Simulate anonymous user (no authentication) @@ -331,7 +348,7 @@ def test_files_authentication_anonymous_access(mock_get_authenticated_user, llam with BytesIO(test_content) as file_buffer: file_buffer.name = "anonymous_file.txt" - anonymous_file = client.files.create(file=file_buffer, purpose="assistants") + anonymous_file = client.files.create(file=file_buffer, purpose=purpose) try: # Anonymous user should be able to access their own uploaded file @@ -344,12 +361,13 @@ def test_files_authentication_anonymous_access(mock_get_authenticated_user, llam assert retrieved_file.id == anonymous_file.id # Can access file content - content_response = client.files.content(anonymous_file.id) - if isinstance(content_response, str): - content = bytes(content_response, "utf-8") - else: - content = content_response.content - assert content == test_content + if not provider_type_is_openai: + content_response = client.files.content(anonymous_file.id) + if isinstance(content_response, str): + content = bytes(content_response, "utf-8") + else: + content = content_response.content + assert content == test_content # Can delete the file delete_response = client.files.delete(anonymous_file.id)