From 822c047ba001a2cb72e2d20943c1849e866e4c8f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 30 Sep 2025 10:44:50 -0700 Subject: [PATCH] fix(expires_after): make sure multipart/form-data is properly parsed --- .../providers/inline/files/localfs/files.py | 5 ++- .../providers/remote/files/s3/files.py | 5 ++- .../providers/utils/files/form_data.py | 41 +++++++++++++++++++ 3 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 llama_stack/providers/utils/files/form_data.py diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index 6e0c72de3..be1da291a 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -9,7 +9,7 @@ import uuid from pathlib import Path from typing import Annotated -from fastapi import File, Form, Response, UploadFile +from fastapi import Depends, File, Form, Response, UploadFile from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.common.responses import Order @@ -23,6 +23,7 @@ from llama_stack.apis.files import ( ) from llama_stack.core.datatypes import AccessRule from llama_stack.log import get_logger +from llama_stack.providers.utils.files.form_data import parse_expires_after from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl @@ -87,7 +88,7 @@ class LocalfsFilesImpl(Files): self, file: Annotated[UploadFile, File()], purpose: Annotated[OpenAIFilePurpose, Form()], - expires_after: Annotated[ExpiresAfter | None, Form()] = None, + expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None, ) -> OpenAIFileObject: """Upload a file that can be used across various endpoints.""" if not self.sql_store: diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index 8520f70b6..eb339b31e 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -10,7 +10,7 @@ from typing import Annotated, Any import boto3 from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError -from fastapi import File, Form, Response, UploadFile +from fastapi import Depends, File, Form, Response, UploadFile from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.common.responses import Order @@ -23,6 +23,7 @@ from llama_stack.apis.files import ( OpenAIFilePurpose, ) from llama_stack.core.datatypes import AccessRule +from llama_stack.providers.utils.files.form_data import parse_expires_after from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl @@ -195,7 +196,7 @@ class S3FilesImpl(Files): self, file: Annotated[UploadFile, File()], purpose: Annotated[OpenAIFilePurpose, Form()], - expires_after: Annotated[ExpiresAfter | None, Form()] = None, + expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None, ) -> OpenAIFileObject: file_id = f"file-{uuid.uuid4().hex}" diff --git a/llama_stack/providers/utils/files/form_data.py b/llama_stack/providers/utils/files/form_data.py new file mode 100644 index 000000000..0e3cfc8d1 --- /dev/null +++ b/llama_stack/providers/utils/files/form_data.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from fastapi import Request + +from llama_stack.apis.files import ExpiresAfter + + +async def parse_expires_after(request: Request) -> ExpiresAfter | None: + """ + Dependency to parse expires_after from multipart form data. + Handles both bracket notation (expires_after[anchor], expires_after[seconds]) + and JSON string format. + """ + form = await request.form() + + # Check for bracket notation first + anchor_key = "expires_after[anchor]" + seconds_key = "expires_after[seconds]" + + if anchor_key in form and seconds_key in form: + anchor = form[anchor_key] + seconds = form[seconds_key] + return ExpiresAfter(anchor=anchor, seconds=int(seconds)) + + # Check for JSON string format + if "expires_after" in form: + value = form["expires_after"] + if isinstance(value, str): + import json + + try: + data = json.loads(value) + return ExpiresAfter(**data) + except (json.JSONDecodeError, TypeError): + pass + + return None