fix(expires_after): make sure multipart/form-data is properly parsed

This commit is contained in:
Ashwin Bharambe 2025-09-30 10:44:50 -07:00
parent 6cce553c93
commit 822c047ba0
3 changed files with 47 additions and 4 deletions

View file

@ -9,7 +9,7 @@ import uuid
from pathlib import Path from pathlib import Path
from typing import Annotated from typing import Annotated
from fastapi import File, Form, Response, UploadFile from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order from llama_stack.apis.common.responses import Order
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
) )
from llama_stack.core.datatypes import AccessRule from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
@ -87,7 +88,7 @@ class LocalfsFilesImpl(Files):
self, self,
file: Annotated[UploadFile, File()], file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()], purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None, expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject: ) -> OpenAIFileObject:
"""Upload a file that can be used across various endpoints.""" """Upload a file that can be used across various endpoints."""
if not self.sql_store: if not self.sql_store:

View file

@ -10,7 +10,7 @@ from typing import Annotated, Any
import boto3 import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order from llama_stack.apis.common.responses import Order
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose, OpenAIFilePurpose,
) )
from llama_stack.core.datatypes import AccessRule from llama_stack.core.datatypes import AccessRule
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
@ -195,7 +196,7 @@ class S3FilesImpl(Files):
self, self,
file: Annotated[UploadFile, File()], file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()], purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None, expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject: ) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}" file_id = f"file-{uuid.uuid4().hex}"

View file

@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from fastapi import Request
from llama_stack.apis.files import ExpiresAfter
async def parse_expires_after(request: Request) -> ExpiresAfter | None:
"""
Dependency to parse expires_after from multipart form data.
Handles both bracket notation (expires_after[anchor], expires_after[seconds])
and JSON string format.
"""
form = await request.form()
# Check for bracket notation first
anchor_key = "expires_after[anchor]"
seconds_key = "expires_after[seconds]"
if anchor_key in form and seconds_key in form:
anchor = form[anchor_key]
seconds = form[seconds_key]
return ExpiresAfter(anchor=anchor, seconds=int(seconds))
# Check for JSON string format
if "expires_after" in form:
value = form["expires_after"]
if isinstance(value, str):
import json
try:
data = json.loads(value)
return ExpiresAfter(**data)
except (json.JSONDecodeError, TypeError):
pass
return None