mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
feat: openai files provider (#3946)
# What does this PR do? - Adds OpenAI files provider - Note that file content retrieval is pretty limited by `purpose` https://community.openai.com/t/file-uploads-error-why-can-t-i-download-files-with-purpose-user-data/1357013 ## Test Plan Modify run yaml to use openai files provider: ``` files: - provider_id: openai provider_type: remote::openai config: api_key: ${env.OPENAI_API_KEY:=} metadata_store: backend: sql_default table_name: openai_files_metadata # Then run files tests ❯ uv run --no-sync ./scripts/integration-tests.sh --stack-config server:ci-tests --inference-mode replay --setup ollama --suite base --pattern test_files ```
This commit is contained in:
parent
feabcdd67b
commit
1f9d48cd54
6 changed files with 367 additions and 27 deletions
27
docs/docs/providers/files/remote_openai.mdx
Normal file
27
docs/docs/providers/files/remote_openai.mdx
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
---
|
||||||
|
description: "OpenAI Files API provider for managing files through OpenAI's native file storage service."
|
||||||
|
sidebar_label: Remote - Openai
|
||||||
|
title: remote::openai
|
||||||
|
---
|
||||||
|
|
||||||
|
# remote::openai
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
OpenAI Files API provider for managing files through OpenAI's native file storage service.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
| Field | Type | Required | Default | Description |
|
||||||
|
|-------|------|----------|---------|-------------|
|
||||||
|
| `api_key` | `<class 'str'>` | No | | OpenAI API key for authentication |
|
||||||
|
| `metadata_store` | `<class 'llama_stack.core.storage.datatypes.SqlStoreReference'>` | No | | SQL store configuration for file metadata |
|
||||||
|
|
||||||
|
## Sample Configuration
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
api_key: ${env.OPENAI_API_KEY}
|
||||||
|
metadata_store:
|
||||||
|
table_name: openai_files_metadata
|
||||||
|
backend: sql_default
|
||||||
|
```
|
||||||
|
|
@ -28,4 +28,13 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||||
),
|
),
|
||||||
|
RemoteProviderSpec(
|
||||||
|
api=Api.files,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
adapter_type="openai",
|
||||||
|
pip_packages=["openai"] + sql_store_pip_packages,
|
||||||
|
module="llama_stack.providers.remote.files.openai",
|
||||||
|
config_class="llama_stack.providers.remote.files.openai.config.OpenAIFilesImplConfig",
|
||||||
|
description="OpenAI Files API provider for managing files through OpenAI's native file storage service.",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
19
src/llama_stack/providers/remote/files/openai/__init__.py
Normal file
19
src/llama_stack/providers/remote/files/openai/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.core.datatypes import AccessRule, Api
|
||||||
|
|
||||||
|
from .config import OpenAIFilesImplConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: OpenAIFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
    """Create, initialize, and return the OpenAI files provider implementation.

    :param config: provider configuration (API key and metadata store reference)
    :param deps: resolved API dependencies (unused by this provider)
    :param policy: access rules applied to the metadata store; defaults to none
    :returns: an initialized ``OpenAIFilesImpl`` instance
    """
    # Imported here so the implementation module is only loaded when the
    # adapter is actually instantiated.
    from .files import OpenAIFilesImpl

    effective_policy: list[AccessRule] = [] if policy is None else policy
    adapter = OpenAIFilesImpl(config, effective_policy)
    await adapter.initialize()
    return adapter
|
||||||
28
src/llama_stack/providers/remote/files/openai/config.py
Normal file
28
src/llama_stack/providers/remote/files/openai/config.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.core.storage.datatypes import SqlStoreReference
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIFilesImplConfig(BaseModel):
    """Configuration for OpenAI Files API provider."""

    # Credential used for all calls to the OpenAI Files API.
    api_key: str = Field(description="OpenAI API key for authentication")
    # Reference to the SQL backend/table that holds per-file metadata.
    metadata_store: SqlStoreReference = Field(description="SQL store configuration for file metadata")

    @classmethod
    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
        """Return a sample run configuration for this provider.

        The API key is referenced via environment-variable substitution
        rather than embedded literally in the generated config.
        """
        store_ref = SqlStoreReference(
            backend="sql_default",
            table_name="openai_files_metadata",
        )
        return {
            "api_key": "${env.OPENAI_API_KEY}",
            "metadata_store": store_ref.model_dump(exclude_none=True),
        }
|
||||||
239
src/llama_stack/providers/remote/files/openai/files.py
Normal file
239
src/llama_stack/providers/remote/files/openai/files.py
Normal file
|
|
@ -0,0 +1,239 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
from fastapi import Depends, File, Form, Response, UploadFile
|
||||||
|
|
||||||
|
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||||
|
from llama_stack.apis.common.responses import Order
|
||||||
|
from llama_stack.apis.files import (
|
||||||
|
ExpiresAfter,
|
||||||
|
Files,
|
||||||
|
ListOpenAIFileResponse,
|
||||||
|
OpenAIFileDeleteResponse,
|
||||||
|
OpenAIFileObject,
|
||||||
|
OpenAIFilePurpose,
|
||||||
|
)
|
||||||
|
from llama_stack.core.datatypes import AccessRule
|
||||||
|
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||||
|
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||||
|
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from .config import OpenAIFilesImplConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _make_file_object(
    *,
    id: str,
    filename: str,
    purpose: str,
    bytes: int,
    created_at: int,
    expires_at: int,
    **kwargs: Any,
) -> OpenAIFileObject:
    """Build an ``OpenAIFileObject`` from stored metadata, normalizing expiration.

    An ``expires_at`` more than ``ExpiresAfter.MAX`` seconds past
    ``created_at`` is treated as "never expires" and reported as ``None``.
    Extra keyword arguments (e.g. surplus database columns) are ignored,
    which lets callers splat whole metadata rows into this function.

    Note: ``id`` and ``bytes`` shadow builtins deliberately — they mirror
    the OpenAI file-object field names used as row keys.
    """
    file_object = OpenAIFileObject(
        id=id,
        filename=filename,
        purpose=OpenAIFilePurpose(purpose),
        bytes=bytes,
        created_at=created_at,
        expires_at=expires_at,
    )

    no_expiration_cutoff = file_object.created_at + ExpiresAfter.MAX
    if file_object.expires_at is not None and file_object.expires_at > no_expiration_cutoff:
        file_object.expires_at = None  # type: ignore

    return file_object
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIFilesImpl(Files):
    """OpenAI Files API implementation.

    Files live in OpenAI's native file storage; this class keeps a parallel
    metadata row per file (id, filename, purpose, size, timestamps) in an
    access-controlled SQL table named ``openai_files`` so that listing,
    expiry, and authorization can be handled locally.

    NOTE(review): the synchronous ``OpenAI`` client is called directly inside
    async methods (`files.create`/`delete`/`content`), which blocks the event
    loop for the duration of each HTTP call — consider ``AsyncOpenAI``.
    """

    def __init__(self, config: OpenAIFilesImplConfig, policy: list[AccessRule]) -> None:
        # Client and store are created lazily in initialize(); the properties
        # below assert they exist before use.
        self._config = config
        self.policy = policy
        self._client: OpenAI | None = None
        self._sql_store: AuthorizedSqlStore | None = None

    def _now(self) -> int:
        """Return current UTC timestamp as int seconds."""
        return int(datetime.now(UTC).timestamp())

    async def _get_file(self, file_id: str, return_expired: bool = False) -> dict[str, Any]:
        # Fetch the metadata row for file_id; by default an expired file is
        # treated as missing (expires_at must be strictly in the future).
        where: dict[str, str | dict] = {"id": file_id}
        if not return_expired:
            where["expires_at"] = {">": self._now()}
        if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
            raise ResourceNotFoundError(file_id, "File", "files.list()")
        return row

    async def _delete_file(self, file_id: str) -> None:
        """Delete a file from OpenAI and the database."""
        try:
            self.client.files.delete(file_id)
        except Exception as e:
            # If file doesn't exist on OpenAI side, just remove from metadata store
            if "not found" not in str(e).lower():
                raise RuntimeError(f"Failed to delete file from OpenAI: {e}") from e

        await self.sql_store.delete("openai_files", where={"id": file_id})

    async def _delete_if_expired(self, file_id: str) -> None:
        """If the file exists and is expired, delete it."""
        # _get_file raises if the row is missing entirely, so this only
        # reaches the expiry check for files we still have metadata for.
        if row := await self._get_file(file_id, return_expired=True):
            if (expires_at := row.get("expires_at")) and expires_at <= self._now():
                await self._delete_file(file_id)

    async def initialize(self) -> None:
        # Construct the OpenAI client and the authorized metadata store, and
        # create the metadata table if it does not exist yet.
        self._client = OpenAI(api_key=self._config.api_key)

        self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
        await self._sql_store.create_table(
            "openai_files",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "filename": ColumnType.STRING,
                "purpose": ColumnType.STRING,
                "bytes": ColumnType.INTEGER,
                "created_at": ColumnType.INTEGER,
                "expires_at": ColumnType.INTEGER,
            },
        )

    async def shutdown(self) -> None:
        # Nothing to tear down: the sync OpenAI client holds no resources
        # that require explicit cleanup here.
        pass

    @property
    def client(self) -> OpenAI:
        # Guard against use before initialize().
        assert self._client is not None, "Provider not initialized"
        return self._client

    @property
    def sql_store(self) -> AuthorizedSqlStore:
        # Guard against use before initialize().
        assert self._sql_store is not None, "Provider not initialized"
        return self._sql_store

    async def openai_upload_file(
        self,
        file: Annotated[UploadFile, File()],
        purpose: Annotated[OpenAIFilePurpose, Form()],
        expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
    ) -> OpenAIFileObject:
        """Upload a file to OpenAI and record its metadata locally.

        Expiration precedence: explicit ``expires_after`` wins; otherwise
        BATCH files default to the maximum allowed TTL; all other purposes
        get a far-future sentinel that _make_file_object reports as
        "never expires" (expires_at=None).
        """
        filename = getattr(file, "filename", None) or "uploaded_file"
        content = await file.read()
        file_size = len(content)

        created_at = self._now()

        # MAX * 42 is a sentinel beyond the normalization cutoff in
        # _make_file_object, so it surfaces to clients as no expiration
        # while still satisfying the expires_at > now listing filter.
        expires_at = created_at + ExpiresAfter.MAX * 42
        if purpose == OpenAIFilePurpose.BATCH:
            expires_at = created_at + ExpiresAfter.MAX

        if expires_after is not None:
            expires_at = created_at + expires_after.seconds

        try:
            from io import BytesIO

            # The OpenAI SDK reads the filename from the file object's
            # .name attribute, so set it on the in-memory buffer.
            file_obj = BytesIO(content)
            file_obj.name = filename

            response = self.client.files.create(
                file=file_obj,
                purpose=purpose.value,
            )

            file_id = response.id

            entry: dict[str, Any] = {
                "id": file_id,
                "filename": filename,
                "purpose": purpose.value,
                "bytes": file_size,
                "created_at": created_at,
                "expires_at": expires_at,
            }

            await self.sql_store.insert("openai_files", entry)

            return _make_file_object(**entry)

        except Exception as e:
            # NOTE(review): if the DB insert fails after the OpenAI upload
            # succeeded, the remote file is orphaned — confirm acceptable.
            raise RuntimeError(f"Failed to upload file to OpenAI: {e}") from e

    async def openai_list_files(
        self,
        after: str | None = None,
        limit: int | None = 10000,
        order: Order | None = Order.desc,
        purpose: OpenAIFilePurpose | None = None,
    ) -> ListOpenAIFileResponse:
        """List non-expired files from local metadata (no OpenAI API call)."""
        if not order:
            order = Order.desc

        # Expired files are filtered out here rather than eagerly deleted.
        where_conditions: dict[str, Any] = {"expires_at": {">": self._now()}}
        if purpose:
            where_conditions["purpose"] = purpose.value

        paginated_result = await self.sql_store.fetch_all(
            table="openai_files",
            where=where_conditions,
            order_by=[("created_at", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
        )

        files = [_make_file_object(**row) for row in paginated_result.data]

        return ListOpenAIFileResponse(
            data=files,
            has_more=paginated_result.has_more,
            first_id=files[0].id if files else "",
            last_id=files[-1].id if files else "",
        )

    async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
        """Return metadata for a file, lazily deleting it first if expired."""
        await self._delete_if_expired(file_id)
        row = await self._get_file(file_id)
        return _make_file_object(**row)

    async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
        """Delete a file from OpenAI and the metadata store."""
        await self._delete_if_expired(file_id)
        # Raises ResourceNotFoundError if the file is unknown/expired.
        _ = await self._get_file(file_id)
        await self._delete_file(file_id)
        return OpenAIFileDeleteResponse(id=file_id, deleted=True)

    async def openai_retrieve_file_content(self, file_id: str) -> Response:
        """Download a file's raw bytes from OpenAI as an attachment response.

        NOTE(review): OpenAI restricts content retrieval for some `purpose`
        values (see PR description) — such failures surface via the generic
        RuntimeError branch below.
        """
        await self._delete_if_expired(file_id)

        row = await self._get_file(file_id)

        try:
            response = self.client.files.content(file_id)
            file_content = response.content

        except Exception as e:
            # Remote says the file is gone: drop our stale metadata row and
            # report not-found to the caller.
            if "not found" in str(e).lower():
                await self._delete_file(file_id)
                raise ResourceNotFoundError(file_id, "File", "files.list()") from e
            raise RuntimeError(f"Failed to download file from OpenAI: {e}") from e

        return Response(
            content=file_content,
            media_type="application/octet-stream",
            headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
        )
||||||
|
|
@ -10,8 +10,18 @@ from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from llama_stack.apis.files import OpenAIFilePurpose
|
||||||
from llama_stack.core.datatypes import User
|
from llama_stack.core.datatypes import User
|
||||||
|
|
||||||
|
purpose = OpenAIFilePurpose.ASSISTANTS
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def provider_type_is_openai(llama_stack_client):
    """Return True when the single configured files provider is remote::openai.

    Tests use this to branch around operations the OpenAI backend cannot
    satisfy (content retrieval is restricted for some `purpose` values).
    """
    providers = [provider for provider in llama_stack_client.providers.list() if provider.api == "files"]
    assert len(providers) == 1, "Expected exactly one files provider"
    return providers[0].provider_type == "remote::openai"
|
||||||
|
|
||||||
|
|
||||||
# a fixture to skip all these tests if a files provider is not available
|
# a fixture to skip all these tests if a files provider is not available
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
|
|
@ -20,7 +30,7 @@ def skip_if_no_files_provider(llama_stack_client):
|
||||||
pytest.skip("No files providers found")
|
pytest.skip("No files providers found")
|
||||||
|
|
||||||
|
|
||||||
def test_openai_client_basic_operations(openai_client):
|
def test_openai_client_basic_operations(openai_client, provider_type_is_openai):
|
||||||
"""Test basic file operations through OpenAI client."""
|
"""Test basic file operations through OpenAI client."""
|
||||||
from openai import NotFoundError
|
from openai import NotFoundError
|
||||||
|
|
||||||
|
|
@ -34,7 +44,7 @@ def test_openai_client_basic_operations(openai_client):
|
||||||
# Upload file using OpenAI client
|
# Upload file using OpenAI client
|
||||||
with BytesIO(test_content) as file_buffer:
|
with BytesIO(test_content) as file_buffer:
|
||||||
file_buffer.name = "openai_test.txt"
|
file_buffer.name = "openai_test.txt"
|
||||||
uploaded_file = client.files.create(file=file_buffer, purpose="assistants")
|
uploaded_file = client.files.create(file=file_buffer, purpose=purpose)
|
||||||
|
|
||||||
# Verify basic response structure
|
# Verify basic response structure
|
||||||
assert uploaded_file.id.startswith("file-")
|
assert uploaded_file.id.startswith("file-")
|
||||||
|
|
@ -50,16 +60,18 @@ def test_openai_client_basic_operations(openai_client):
|
||||||
retrieved_file = client.files.retrieve(uploaded_file.id)
|
retrieved_file = client.files.retrieve(uploaded_file.id)
|
||||||
assert retrieved_file.id == uploaded_file.id
|
assert retrieved_file.id == uploaded_file.id
|
||||||
|
|
||||||
# Retrieve file content - OpenAI client returns httpx Response object
|
# Retrieve file content
|
||||||
content_response = client.files.content(uploaded_file.id)
|
# OpenAI provider does not allow content retrieval with many `purpose` values
|
||||||
assert content_response.content == test_content
|
if not provider_type_is_openai:
|
||||||
|
content_response = client.files.content(uploaded_file.id)
|
||||||
|
assert content_response.content == test_content
|
||||||
|
|
||||||
# Delete file
|
# Delete file
|
||||||
delete_response = client.files.delete(uploaded_file.id)
|
delete_response = client.files.delete(uploaded_file.id)
|
||||||
assert delete_response.deleted is True
|
assert delete_response.deleted is True
|
||||||
|
|
||||||
# Retrieve file should fail
|
# Retrieve file should fail
|
||||||
with pytest.raises(NotFoundError, match="not found"):
|
with pytest.raises(NotFoundError):
|
||||||
client.files.retrieve(uploaded_file.id)
|
client.files.retrieve(uploaded_file.id)
|
||||||
|
|
||||||
# File should not be found in listing
|
# File should not be found in listing
|
||||||
|
|
@ -68,7 +80,7 @@ def test_openai_client_basic_operations(openai_client):
|
||||||
assert uploaded_file.id not in file_ids
|
assert uploaded_file.id not in file_ids
|
||||||
|
|
||||||
# Double delete should fail
|
# Double delete should fail
|
||||||
with pytest.raises(NotFoundError, match="not found"):
|
with pytest.raises(NotFoundError):
|
||||||
client.files.delete(uploaded_file.id)
|
client.files.delete(uploaded_file.id)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -91,7 +103,7 @@ def test_expires_after(openai_client):
|
||||||
file_buffer.name = "expires_after.txt"
|
file_buffer.name = "expires_after.txt"
|
||||||
uploaded_file = client.files.create(
|
uploaded_file = client.files.create(
|
||||||
file=file_buffer,
|
file=file_buffer,
|
||||||
purpose="assistants",
|
purpose=purpose,
|
||||||
expires_after={"anchor": "created_at", "seconds": 4545},
|
expires_after={"anchor": "created_at", "seconds": 4545},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -126,7 +138,7 @@ def test_expires_after_requests(openai_client):
|
||||||
try:
|
try:
|
||||||
files = {"file": ("expires_after_with_requests.txt", BytesIO(b"expires_after via requests"))}
|
files = {"file": ("expires_after_with_requests.txt", BytesIO(b"expires_after via requests"))}
|
||||||
data = {
|
data = {
|
||||||
"purpose": "assistants",
|
"purpose": str(purpose),
|
||||||
"expires_after[anchor]": "created_at",
|
"expires_after[anchor]": "created_at",
|
||||||
"expires_after[seconds]": "4545",
|
"expires_after[seconds]": "4545",
|
||||||
}
|
}
|
||||||
|
|
@ -180,7 +192,7 @@ def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack
|
||||||
|
|
||||||
with BytesIO(test_content_1) as file_buffer:
|
with BytesIO(test_content_1) as file_buffer:
|
||||||
file_buffer.name = "user1_file.txt"
|
file_buffer.name = "user1_file.txt"
|
||||||
user1_file = client.files.create(file=file_buffer, purpose="assistants")
|
user1_file = client.files.create(file=file_buffer, purpose=purpose)
|
||||||
|
|
||||||
# User 2 uploads a file
|
# User 2 uploads a file
|
||||||
mock_get_authenticated_user.return_value = user2
|
mock_get_authenticated_user.return_value = user2
|
||||||
|
|
@ -188,7 +200,7 @@ def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack
|
||||||
|
|
||||||
with BytesIO(test_content_2) as file_buffer:
|
with BytesIO(test_content_2) as file_buffer:
|
||||||
file_buffer.name = "user2_file.txt"
|
file_buffer.name = "user2_file.txt"
|
||||||
user2_file = client.files.create(file=file_buffer, purpose="assistants")
|
user2_file = client.files.create(file=file_buffer, purpose=purpose)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# User 1 can see their own file
|
# User 1 can see their own file
|
||||||
|
|
@ -264,7 +276,9 @@ def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack
|
||||||
|
|
||||||
|
|
||||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
def test_files_authentication_shared_attributes(mock_get_authenticated_user, llama_stack_client):
|
def test_files_authentication_shared_attributes(
|
||||||
|
mock_get_authenticated_user, llama_stack_client, provider_type_is_openai
|
||||||
|
):
|
||||||
"""Test access control with users having identical attributes."""
|
"""Test access control with users having identical attributes."""
|
||||||
client = llama_stack_client
|
client = llama_stack_client
|
||||||
|
|
||||||
|
|
@ -278,7 +292,7 @@ def test_files_authentication_shared_attributes(mock_get_authenticated_user, lla
|
||||||
|
|
||||||
with BytesIO(test_content) as file_buffer:
|
with BytesIO(test_content) as file_buffer:
|
||||||
file_buffer.name = "shared_attributes_file.txt"
|
file_buffer.name = "shared_attributes_file.txt"
|
||||||
shared_file = client.files.create(file=file_buffer, purpose="assistants")
|
shared_file = client.files.create(file=file_buffer, purpose=purpose)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# User B with identical attributes can access the file
|
# User B with identical attributes can access the file
|
||||||
|
|
@ -294,12 +308,13 @@ def test_files_authentication_shared_attributes(mock_get_authenticated_user, lla
|
||||||
assert retrieved_file.id == shared_file.id
|
assert retrieved_file.id == shared_file.id
|
||||||
|
|
||||||
# User B can access file content
|
# User B can access file content
|
||||||
content_response = client.files.content(shared_file.id)
|
if not provider_type_is_openai:
|
||||||
if isinstance(content_response, str):
|
content_response = client.files.content(shared_file.id)
|
||||||
content = bytes(content_response, "utf-8")
|
if isinstance(content_response, str):
|
||||||
else:
|
content = bytes(content_response, "utf-8")
|
||||||
content = content_response.content
|
else:
|
||||||
assert content == test_content
|
content = content_response.content
|
||||||
|
assert content == test_content
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
mock_get_authenticated_user.return_value = user_a
|
mock_get_authenticated_user.return_value = user_a
|
||||||
|
|
@ -321,7 +336,9 @@ def test_files_authentication_shared_attributes(mock_get_authenticated_user, lla
|
||||||
|
|
||||||
|
|
||||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
def test_files_authentication_anonymous_access(mock_get_authenticated_user, llama_stack_client):
|
def test_files_authentication_anonymous_access(
|
||||||
|
mock_get_authenticated_user, llama_stack_client, provider_type_is_openai
|
||||||
|
):
|
||||||
client = llama_stack_client
|
client = llama_stack_client
|
||||||
|
|
||||||
# Simulate anonymous user (no authentication)
|
# Simulate anonymous user (no authentication)
|
||||||
|
|
@ -331,7 +348,7 @@ def test_files_authentication_anonymous_access(mock_get_authenticated_user, llam
|
||||||
|
|
||||||
with BytesIO(test_content) as file_buffer:
|
with BytesIO(test_content) as file_buffer:
|
||||||
file_buffer.name = "anonymous_file.txt"
|
file_buffer.name = "anonymous_file.txt"
|
||||||
anonymous_file = client.files.create(file=file_buffer, purpose="assistants")
|
anonymous_file = client.files.create(file=file_buffer, purpose=purpose)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Anonymous user should be able to access their own uploaded file
|
# Anonymous user should be able to access their own uploaded file
|
||||||
|
|
@ -344,12 +361,13 @@ def test_files_authentication_anonymous_access(mock_get_authenticated_user, llam
|
||||||
assert retrieved_file.id == anonymous_file.id
|
assert retrieved_file.id == anonymous_file.id
|
||||||
|
|
||||||
# Can access file content
|
# Can access file content
|
||||||
content_response = client.files.content(anonymous_file.id)
|
if not provider_type_is_openai:
|
||||||
if isinstance(content_response, str):
|
content_response = client.files.content(anonymous_file.id)
|
||||||
content = bytes(content_response, "utf-8")
|
if isinstance(content_response, str):
|
||||||
else:
|
content = bytes(content_response, "utf-8")
|
||||||
content = content_response.content
|
else:
|
||||||
assert content == test_content
|
content = content_response.content
|
||||||
|
assert content == test_content
|
||||||
|
|
||||||
# Can delete the file
|
# Can delete the file
|
||||||
delete_response = client.files.delete(anonymous_file.id)
|
delete_response = client.files.delete(anonymous_file.id)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue